torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,878 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import gzip
|
|
9
|
+
import io
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
import shutil
|
|
13
|
+
import subprocess
|
|
14
|
+
import tempfile
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
from collections.abc import Callable
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
from tensordict import MemoryMappedTensor, TensorDict, TensorDictBase
|
|
22
|
+
from torch import multiprocessing as mp
|
|
23
|
+
from torchrl._utils import logger as torchrl_logger
|
|
24
|
+
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
|
|
25
|
+
from torchrl.data.replay_buffers.samplers import (
|
|
26
|
+
SamplerWithoutReplacement,
|
|
27
|
+
SliceSampler,
|
|
28
|
+
SliceSamplerWithoutReplacement,
|
|
29
|
+
)
|
|
30
|
+
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
|
31
|
+
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
|
|
32
|
+
from torchrl.data.utils import CloudpickleWrapper
|
|
33
|
+
from torchrl.envs.utils import _classproperty
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
|
|
37
|
+
"""Atari DQN Experience replay class.
|
|
38
|
+
|
|
39
|
+
The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training
|
|
40
|
+
iterations of DQN over each of the Atari 2600 games for a total of 200 million frames.
|
|
41
|
+
The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset
|
|
42
|
+
has 50 million steps in total.
|
|
43
|
+
|
|
44
|
+
The data format follows the :ref:`TED convention <TED-format>`. Since the dataset is quite heavy,
|
|
45
|
+
the data formatting is done on-line, at sampling time.
|
|
46
|
+
|
|
47
|
+
To make training more modular, we split the dataset in each of the Atari games
|
|
48
|
+
and separate each training round. Consequently, each dataset is presented as
|
|
49
|
+
a Storage of length 50x10^6 elements. Under the hood, this dataset is split
|
|
50
|
+
in 50 memory-mapped tensordicts of length 1 million each.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
dataset_id (str): The dataset to be downloaded.
|
|
54
|
+
Must be part of ``AtariDQNExperienceReplay.available_datasets``.
|
|
55
|
+
batch_size (int): Batch-size used during sampling.
|
|
56
|
+
Can be overridden by `data.sample(batch_size)` if necessary.
|
|
57
|
+
|
|
58
|
+
Keyword Args:
|
|
59
|
+
root (Path or str, optional): The AtariDQN dataset root directory.
|
|
60
|
+
The actual dataset memory-mapped files will be saved under
|
|
61
|
+
`<root>/<dataset_id>`. If none is provided, it defaults to
|
|
62
|
+
`~/.cache/torchrl/atari`.
|
|
63
|
+
num_procs (int, optional): number of processes to launch for preprocessing.
|
|
64
|
+
Has no effect whenever the data is already downloaded. Defaults to 0
|
|
65
|
+
(no multiprocessing used).
|
|
66
|
+
download (bool or str, optional): Whether the dataset should be downloaded if
|
|
67
|
+
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
|
|
68
|
+
in which case the downloaded data will be overwritten.
|
|
69
|
+
sampler (Sampler, optional): the sampler to be used. If none is provided
|
|
70
|
+
a default RandomSampler() will be used.
|
|
71
|
+
writer (Writer, optional): the writer to be used. If none is provided
|
|
72
|
+
a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
|
|
73
|
+
collate_fn (callable, optional): merges a list of samples to form a
|
|
74
|
+
mini-batch of Tensor(s)/outputs. Used when using batched
|
|
75
|
+
loading from a map-style dataset.
|
|
76
|
+
pin_memory (bool): whether pin_memory() should be called on the rb
|
|
77
|
+
samples.
|
|
78
|
+
prefetch (int, optional): number of next batches to be prefetched
|
|
79
|
+
using multithreading.
|
|
80
|
+
transform (Transform, optional): Transform to be executed when sample() is called.
|
|
81
|
+
To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
|
|
82
|
+
num_slices (int, optional): the number of slices to be sampled. The batch-size
|
|
83
|
+
must be greater than or equal to the ``num_slices`` argument. Exclusive
|
|
84
|
+
with ``slice_len``. Defaults to ``None`` (no slice sampling).
|
|
85
|
+
The ``sampler`` arg will override this value.
|
|
86
|
+
slice_len (int, optional): the length of the slices to be sampled. The batch-size
|
|
87
|
+
must be greater than or equal to the ``slice_len`` argument and divisible
|
|
88
|
+
by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling).
|
|
89
|
+
The ``sampler`` arg will override this value.
|
|
90
|
+
strict_length (bool, optional): if ``False``, trajectories of length
|
|
91
|
+
shorter than `slice_len` (or `batch_size // num_slices`) will be
|
|
92
|
+
allowed to appear in the batch.
|
|
93
|
+
Be mindful that this can result in effective `batch_size` shorter
|
|
94
|
+
than the one asked for! Trajectories can be split using
|
|
95
|
+
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
|
|
96
|
+
The ``sampler`` arg will override this value.
|
|
97
|
+
replacement (bool, optional): if ``False``, sampling will occur without replacement.
|
|
98
|
+
The ``sampler`` arg will override this value.
|
|
99
|
+
mp_start_method (str, optional): the start method for multiprocessed
|
|
100
|
+
download. Defaults to ``"fork"``.
|
|
101
|
+
|
|
102
|
+
Attributes:
|
|
103
|
+
available_datasets: list of available datasets, formatted as `<game_name>/<run>`. Example:
|
|
104
|
+
`"Pong/5"`, `"Krull/2"`, ...
|
|
105
|
+
dataset_id (str): the name of the dataset.
|
|
106
|
+
episodes (torch.Tensor): a 1d tensor indicating to what run each of the
|
|
107
|
+
1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler`
|
|
108
|
+
to cheaply sample slices of episodes.
|
|
109
|
+
|
|
110
|
+
Examples:
|
|
111
|
+
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
|
|
112
|
+
>>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128)
|
|
113
|
+
>>> for data in dataset:
|
|
114
|
+
... print(data)
|
|
115
|
+
... break
|
|
116
|
+
TensorDict(
|
|
117
|
+
fields={
|
|
118
|
+
action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
119
|
+
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
120
|
+
index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
121
|
+
metadata: NonTensorData(
|
|
122
|
+
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
|
|
123
|
+
batch_size=torch.Size([128]),
|
|
124
|
+
device=None,
|
|
125
|
+
is_shared=False),
|
|
126
|
+
next: TensorDict(
|
|
127
|
+
fields={
|
|
128
|
+
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
129
|
+
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
130
|
+
reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
131
|
+
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
132
|
+
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
133
|
+
batch_size=torch.Size([128]),
|
|
134
|
+
device=None,
|
|
135
|
+
is_shared=False),
|
|
136
|
+
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
137
|
+
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
138
|
+
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
139
|
+
batch_size=torch.Size([128]),
|
|
140
|
+
device=None,
|
|
141
|
+
is_shared=False)
|
|
142
|
+
|
|
143
|
+
.. warning::
|
|
144
|
+
Atari-DQN does not provide the next observation after a termination signal.
|
|
145
|
+
In other words, there is no way to obtain the ``("next", "observation")`` state
|
|
146
|
+
when ``("next", "done")`` is ``True``. This value is filled with 0s but should
|
|
147
|
+
not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.value.ValueEstimatorBase`)
|
|
148
|
+
are used, this should not be an issue.
|
|
149
|
+
|
|
150
|
+
.. note::
|
|
151
|
+
Because the construction of the sampler for episode sampling is slightly
|
|
152
|
+
convoluted, we made it easy for users to pass the arguments of the
|
|
153
|
+
:class:`~torchrl.data.replay_buffers.SliceSampler` directly to the
|
|
154
|
+
``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or
|
|
155
|
+
``slice_len`` arguments will make the sampler an instance of
|
|
156
|
+
:class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_length``
|
|
157
|
+
can also be passed.
|
|
158
|
+
|
|
159
|
+
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
|
|
160
|
+
>>> from torchrl.data.replay_buffers import SliceSampler
|
|
161
|
+
>>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64)
|
|
162
|
+
>>> for data in dataset:
|
|
163
|
+
... print(data)
|
|
164
|
+
... print(data.get("index")) # indices are in 2 groups of 64 consecutive values (batch_size // slice_len slices)
|
|
165
|
+
... break
|
|
166
|
+
TensorDict(
|
|
167
|
+
fields={
|
|
168
|
+
action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
169
|
+
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
170
|
+
index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
171
|
+
metadata: NonTensorData(
|
|
172
|
+
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
|
|
173
|
+
batch_size=torch.Size([128]),
|
|
174
|
+
device=None,
|
|
175
|
+
is_shared=False),
|
|
176
|
+
next: TensorDict(
|
|
177
|
+
fields={
|
|
178
|
+
done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
179
|
+
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
180
|
+
reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
181
|
+
terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
182
|
+
truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
183
|
+
batch_size=torch.Size([128]),
|
|
184
|
+
device=None,
|
|
185
|
+
is_shared=False),
|
|
186
|
+
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
187
|
+
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
188
|
+
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
189
|
+
batch_size=torch.Size([128]),
|
|
190
|
+
device=None,
|
|
191
|
+
is_shared=False)
|
|
192
|
+
tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635,
|
|
193
|
+
2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643,
|
|
194
|
+
2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651,
|
|
195
|
+
2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659,
|
|
196
|
+
2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667,
|
|
197
|
+
2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675,
|
|
198
|
+
2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683,
|
|
199
|
+
2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691,
|
|
200
|
+
1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694,
|
|
201
|
+
1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702,
|
|
202
|
+
1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710,
|
|
203
|
+
1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718,
|
|
204
|
+
1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726,
|
|
205
|
+
1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734,
|
|
206
|
+
1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742,
|
|
207
|
+
1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750])
|
|
208
|
+
|
|
209
|
+
.. note::
|
|
210
|
+
As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`:
|
|
211
|
+
|
|
212
|
+
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
|
|
213
|
+
>>> from torchrl.data.replay_buffers import ReplayBufferEnsemble
|
|
214
|
+
>>> # we change this parameter for quick experimentation, in practice it should be left untouched
|
|
215
|
+
>>> AtariDQNExperienceReplay._max_runs = 2
|
|
216
|
+
>>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4)
|
|
217
|
+
>>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4)
|
|
218
|
+
>>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True)
|
|
219
|
+
>>> sample = dataset.sample()
|
|
220
|
+
>>> print("first sample, Pong", sample[0])
|
|
221
|
+
first sample, Pong TensorDict(
|
|
222
|
+
fields={
|
|
223
|
+
action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
224
|
+
done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
225
|
+
index: TensorDict(
|
|
226
|
+
fields={
|
|
227
|
+
buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
228
|
+
index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
229
|
+
batch_size=torch.Size([64]),
|
|
230
|
+
device=None,
|
|
231
|
+
is_shared=False),
|
|
232
|
+
metadata: NonTensorData(
|
|
233
|
+
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
|
|
234
|
+
batch_size=torch.Size([64]),
|
|
235
|
+
device=None,
|
|
236
|
+
is_shared=False),
|
|
237
|
+
next: TensorDict(
|
|
238
|
+
fields={
|
|
239
|
+
done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
240
|
+
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
241
|
+
reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
242
|
+
terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
243
|
+
truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
244
|
+
batch_size=torch.Size([64]),
|
|
245
|
+
device=None,
|
|
246
|
+
is_shared=False),
|
|
247
|
+
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
248
|
+
terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
249
|
+
truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
250
|
+
batch_size=torch.Size([64]),
|
|
251
|
+
device=None,
|
|
252
|
+
is_shared=False)
|
|
253
|
+
>>> print("second sample, Asterix", sample[1])
|
|
254
|
+
second sample, Asterix TensorDict(
|
|
255
|
+
fields={
|
|
256
|
+
action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
257
|
+
done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
258
|
+
index: TensorDict(
|
|
259
|
+
fields={
|
|
260
|
+
buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
261
|
+
index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
262
|
+
batch_size=torch.Size([64]),
|
|
263
|
+
device=None,
|
|
264
|
+
is_shared=False),
|
|
265
|
+
metadata: NonTensorData(
|
|
266
|
+
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'},
|
|
267
|
+
batch_size=torch.Size([64]),
|
|
268
|
+
device=None,
|
|
269
|
+
is_shared=False),
|
|
270
|
+
next: TensorDict(
|
|
271
|
+
fields={
|
|
272
|
+
done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
273
|
+
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
274
|
+
reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
275
|
+
terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
276
|
+
truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
277
|
+
batch_size=torch.Size([64]),
|
|
278
|
+
device=None,
|
|
279
|
+
is_shared=False),
|
|
280
|
+
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
281
|
+
terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
282
|
+
truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
283
|
+
batch_size=torch.Size([64]),
|
|
284
|
+
device=None,
|
|
285
|
+
is_shared=False)
|
|
286
|
+
>>> print("Aggregate (metadata hidden)", sample)
|
|
287
|
+
Aggregate (metadata hidden) LazyStackedTensorDict(
|
|
288
|
+
fields={
|
|
289
|
+
action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
290
|
+
done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
291
|
+
index: LazyStackedTensorDict(
|
|
292
|
+
fields={
|
|
293
|
+
buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
294
|
+
index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
295
|
+
exclusive_fields={
|
|
296
|
+
},
|
|
297
|
+
batch_size=torch.Size([2, 64]),
|
|
298
|
+
device=None,
|
|
299
|
+
is_shared=False,
|
|
300
|
+
stack_dim=0),
|
|
301
|
+
metadata: LazyStackedTensorDict(
|
|
302
|
+
fields={
|
|
303
|
+
},
|
|
304
|
+
exclusive_fields={
|
|
305
|
+
},
|
|
306
|
+
batch_size=torch.Size([2, 64]),
|
|
307
|
+
device=None,
|
|
308
|
+
is_shared=False,
|
|
309
|
+
stack_dim=0),
|
|
310
|
+
next: LazyStackedTensorDict(
|
|
311
|
+
fields={
|
|
312
|
+
done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
313
|
+
observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
314
|
+
reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
315
|
+
terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
316
|
+
truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
317
|
+
exclusive_fields={
|
|
318
|
+
},
|
|
319
|
+
batch_size=torch.Size([2, 64]),
|
|
320
|
+
device=None,
|
|
321
|
+
is_shared=False,
|
|
322
|
+
stack_dim=0),
|
|
323
|
+
observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
324
|
+
terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
325
|
+
truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)},
|
|
326
|
+
exclusive_fields={
|
|
327
|
+
},
|
|
328
|
+
batch_size=torch.Size([2, 64]),
|
|
329
|
+
device=None,
|
|
330
|
+
is_shared=False,
|
|
331
|
+
stack_dim=0)
|
|
332
|
+
|
|
333
|
+
"""
|
|
334
|
+
|
|
335
|
+
@_classproperty
|
|
336
|
+
def available_datasets(cls):
|
|
337
|
+
games = [
|
|
338
|
+
"AirRaid",
|
|
339
|
+
"Alien",
|
|
340
|
+
"Amidar",
|
|
341
|
+
"Assault",
|
|
342
|
+
"Asterix",
|
|
343
|
+
"Asteroids",
|
|
344
|
+
"Atlantis",
|
|
345
|
+
"BankHeist",
|
|
346
|
+
"BattleZone",
|
|
347
|
+
"BeamRider",
|
|
348
|
+
"Berzerk",
|
|
349
|
+
"Bowling",
|
|
350
|
+
"Boxing",
|
|
351
|
+
"Breakout",
|
|
352
|
+
"Carnival",
|
|
353
|
+
"Centipede",
|
|
354
|
+
"ChopperCommand",
|
|
355
|
+
"CrazyClimber",
|
|
356
|
+
"DemonAttack",
|
|
357
|
+
"DoubleDunk",
|
|
358
|
+
"ElevatorAction",
|
|
359
|
+
"Enduro",
|
|
360
|
+
"FishingDerby",
|
|
361
|
+
"Freeway",
|
|
362
|
+
"Frostbite",
|
|
363
|
+
"Gopher",
|
|
364
|
+
"Gravitar",
|
|
365
|
+
"Hero",
|
|
366
|
+
"IceHockey",
|
|
367
|
+
"Jamesbond",
|
|
368
|
+
"JourneyEscape",
|
|
369
|
+
"Kangaroo",
|
|
370
|
+
"Krull",
|
|
371
|
+
"KungFuMaster",
|
|
372
|
+
"MontezumaRevenge",
|
|
373
|
+
"MsPacman",
|
|
374
|
+
"NameThisGame",
|
|
375
|
+
"Phoenix",
|
|
376
|
+
"Pitfall",
|
|
377
|
+
"Pong",
|
|
378
|
+
"Pooyan",
|
|
379
|
+
"PrivateEye",
|
|
380
|
+
"Qbert",
|
|
381
|
+
"Riverraid",
|
|
382
|
+
"RoadRunner",
|
|
383
|
+
"Robotank",
|
|
384
|
+
"Seaquest",
|
|
385
|
+
"Skiing",
|
|
386
|
+
"Solaris",
|
|
387
|
+
"SpaceInvaders",
|
|
388
|
+
]
|
|
389
|
+
return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)]
|
|
390
|
+
|
|
391
|
+
# If we want to keep track of the original atari files
|
|
392
|
+
tmpdir = None
|
|
393
|
+
# use _max_runs for debugging, avoids downloading the entire dataset
|
|
394
|
+
_max_runs = None
|
|
395
|
+
|
|
396
|
+
def __init__(
|
|
397
|
+
self,
|
|
398
|
+
dataset_id: str,
|
|
399
|
+
batch_size: int | None = None,
|
|
400
|
+
*,
|
|
401
|
+
root: str | Path | None = None,
|
|
402
|
+
download: bool | str = True,
|
|
403
|
+
sampler=None,
|
|
404
|
+
writer=None,
|
|
405
|
+
transform: Transform | None = None, # noqa: F821
|
|
406
|
+
num_procs: int = 0,
|
|
407
|
+
num_slices: int | None = None,
|
|
408
|
+
slice_len: int | None = None,
|
|
409
|
+
strict_len: bool = True,
|
|
410
|
+
replacement: bool = True,
|
|
411
|
+
mp_start_method: str = "fork",
|
|
412
|
+
**kwargs,
|
|
413
|
+
):
|
|
414
|
+
import warnings
|
|
415
|
+
|
|
416
|
+
warnings.warn(
|
|
417
|
+
"This dataset is no longer available. We are working on a fix, or possibly a deprecation.",
|
|
418
|
+
DeprecationWarning,
|
|
419
|
+
)
|
|
420
|
+
if dataset_id not in self.available_datasets:
|
|
421
|
+
raise ValueError(
|
|
422
|
+
"The dataset_id is not part of the available datasets. The dataset should be named <game_name>/<run> "
|
|
423
|
+
"where <game_name> is one of the Atari 2600 games and the run is a number between 1 and 5. "
|
|
424
|
+
"The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets."
|
|
425
|
+
)
|
|
426
|
+
self.dataset_id = dataset_id
|
|
427
|
+
from torchrl.data.datasets.utils import _get_root_dir
|
|
428
|
+
|
|
429
|
+
if root is None:
|
|
430
|
+
root = _get_root_dir("atari")
|
|
431
|
+
self.root = root
|
|
432
|
+
self.num_procs = num_procs
|
|
433
|
+
self.mp_start_method = mp_start_method
|
|
434
|
+
if download == "force" or (download and not self._is_downloaded):
|
|
435
|
+
try:
|
|
436
|
+
self._download_and_preproc()
|
|
437
|
+
except Exception:
|
|
438
|
+
# remove temporary data
|
|
439
|
+
if os.path.exists(self.dataset_path):
|
|
440
|
+
shutil.rmtree(self.dataset_path)
|
|
441
|
+
raise
|
|
442
|
+
if self._downloaded_and_preproc:
|
|
443
|
+
storage = TensorStorage(TensorDict.load_memmap(self.dataset_path))
|
|
444
|
+
else:
|
|
445
|
+
storage = _AtariStorage(self.dataset_path)
|
|
446
|
+
if writer is None:
|
|
447
|
+
writer = ImmutableDatasetWriter()
|
|
448
|
+
if sampler is None:
|
|
449
|
+
if num_slices is not None or slice_len is not None:
|
|
450
|
+
if not replacement:
|
|
451
|
+
sampler = SliceSamplerWithoutReplacement(
|
|
452
|
+
num_slices=num_slices,
|
|
453
|
+
slice_len=slice_len,
|
|
454
|
+
trajectories=storage.episodes,
|
|
455
|
+
)
|
|
456
|
+
else:
|
|
457
|
+
sampler = SliceSampler(
|
|
458
|
+
num_slices=num_slices,
|
|
459
|
+
slice_len=slice_len,
|
|
460
|
+
trajectories=storage.episodes,
|
|
461
|
+
cache_values=True,
|
|
462
|
+
)
|
|
463
|
+
elif not replacement:
|
|
464
|
+
sampler = SamplerWithoutReplacement()
|
|
465
|
+
|
|
466
|
+
super().__init__(
|
|
467
|
+
storage=storage,
|
|
468
|
+
batch_size=batch_size,
|
|
469
|
+
writer=writer,
|
|
470
|
+
sampler=sampler,
|
|
471
|
+
collate_fn=lambda x: x,
|
|
472
|
+
transform=transform,
|
|
473
|
+
**kwargs,
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
@property
|
|
477
|
+
def episodes(self):
|
|
478
|
+
return self._storage.episodes
|
|
479
|
+
|
|
480
|
+
@property
|
|
481
|
+
def root(self) -> Path:
|
|
482
|
+
return self._root
|
|
483
|
+
|
|
484
|
+
@root.setter
|
|
485
|
+
def root(self, value):
|
|
486
|
+
self._root = Path(value)
|
|
487
|
+
|
|
488
|
+
@property
|
|
489
|
+
def dataset_path(self) -> Path:
|
|
490
|
+
return self._root / self.dataset_id
|
|
491
|
+
|
|
492
|
+
@property
|
|
493
|
+
def _downloaded_and_preproc(self):
|
|
494
|
+
return os.path.exists(self.dataset_path / "meta.json")
|
|
495
|
+
|
|
496
|
+
@property
|
|
497
|
+
def _is_downloaded(self):
|
|
498
|
+
if os.path.exists(self.dataset_path / "meta.json"):
|
|
499
|
+
return True
|
|
500
|
+
if os.path.exists(self.dataset_path / "processed.json"):
|
|
501
|
+
with open(self.dataset_path / "processed.json") as jsonfile:
|
|
502
|
+
return json.load(jsonfile).get("processed", False) == self._max_runs
|
|
503
|
+
return False
|
|
504
|
+
|
|
505
|
+
def _download_and_preproc(self):
|
|
506
|
+
torchrl_logger.info(
|
|
507
|
+
f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..."
|
|
508
|
+
)
|
|
509
|
+
if os.path.exists(self.dataset_path):
|
|
510
|
+
shutil.rmtree(self.dataset_path)
|
|
511
|
+
with tempfile.TemporaryDirectory() as tempdir:
|
|
512
|
+
if self.tmpdir is not None:
|
|
513
|
+
tempdir = self.tmpdir
|
|
514
|
+
if not os.listdir(tempdir):
|
|
515
|
+
os.makedirs(tempdir, exist_ok=True)
|
|
516
|
+
# get the list of runs
|
|
517
|
+
try:
|
|
518
|
+
subprocess.run(
|
|
519
|
+
["gsutil", "version"], check=True, capture_output=True
|
|
520
|
+
)
|
|
521
|
+
except subprocess.CalledProcessError:
|
|
522
|
+
raise RuntimeError("gsutil is not installed or not found in PATH.")
|
|
523
|
+
command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
|
|
524
|
+
output = subprocess.run(
|
|
525
|
+
command, shell=True, capture_output=True
|
|
526
|
+
) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
527
|
+
files = [
|
|
528
|
+
file.decode("utf-8").replace("$", r"\$") # noqa: W605
|
|
529
|
+
for file in output.stdout.splitlines()
|
|
530
|
+
if file.endswith(b".gz")
|
|
531
|
+
]
|
|
532
|
+
self.remote_gz_files = self._list_runs(None, files)
|
|
533
|
+
remote_gz_files = list(self.remote_gz_files)
|
|
534
|
+
if not len(remote_gz_files):
|
|
535
|
+
raise RuntimeError("No files in file list.")
|
|
536
|
+
|
|
537
|
+
total_runs = remote_gz_files[-1]
|
|
538
|
+
if self.num_procs == 0:
|
|
539
|
+
for run, run_files in self.remote_gz_files.items():
|
|
540
|
+
self._download_and_proc_split(
|
|
541
|
+
run,
|
|
542
|
+
run_files,
|
|
543
|
+
tempdir=tempdir,
|
|
544
|
+
dataset_path=self.dataset_path,
|
|
545
|
+
total_episodes=total_runs,
|
|
546
|
+
max_runs=self._max_runs,
|
|
547
|
+
multithreaded=True,
|
|
548
|
+
)
|
|
549
|
+
else:
|
|
550
|
+
func = functools.partial(
|
|
551
|
+
self._download_and_proc_split,
|
|
552
|
+
tempdir=tempdir,
|
|
553
|
+
dataset_path=self.dataset_path,
|
|
554
|
+
total_episodes=total_runs,
|
|
555
|
+
max_runs=self._max_runs,
|
|
556
|
+
multithreaded=False,
|
|
557
|
+
)
|
|
558
|
+
args = [
|
|
559
|
+
(run, run_files)
|
|
560
|
+
for (run, run_files) in self.remote_gz_files.items()
|
|
561
|
+
]
|
|
562
|
+
ctx = mp.get_context(self.mp_start_method)
|
|
563
|
+
with ctx.Pool(self.num_procs) as pool:
|
|
564
|
+
pool.starmap(func, args)
|
|
565
|
+
with open(self.dataset_path / "processed.json", "w") as file:
|
|
566
|
+
# we save self._max_runs such that changing the number of runs to process
|
|
567
|
+
# forces the data to be re-downloaded
|
|
568
|
+
json.dump({"processed": self._max_runs}, file)
|
|
569
|
+
|
|
570
|
+
@classmethod
|
|
571
|
+
def _download_and_proc_split(
|
|
572
|
+
cls,
|
|
573
|
+
run,
|
|
574
|
+
run_files,
|
|
575
|
+
*,
|
|
576
|
+
tempdir,
|
|
577
|
+
dataset_path,
|
|
578
|
+
total_episodes,
|
|
579
|
+
max_runs,
|
|
580
|
+
multithreaded=True,
|
|
581
|
+
):
|
|
582
|
+
if (max_runs is not None) and (run >= max_runs):
|
|
583
|
+
return
|
|
584
|
+
tempdir = Path(tempdir)
|
|
585
|
+
os.makedirs(tempdir / str(run))
|
|
586
|
+
files_str = " ".join(run_files) # .decode("utf-8")
|
|
587
|
+
torchrl_logger.info(f"downloading {files_str}")
|
|
588
|
+
if multithreaded:
|
|
589
|
+
command = f"gsutil -m cp {files_str} {tempdir}/{run}"
|
|
590
|
+
else:
|
|
591
|
+
command = f"gsutil cp {files_str} {tempdir}/{run}"
|
|
592
|
+
subprocess.run(
|
|
593
|
+
command, shell=True
|
|
594
|
+
) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
595
|
+
local_gz_files = cls._list_runs(tempdir / str(run))
|
|
596
|
+
# we iterate over the dict but this one has length 1
|
|
597
|
+
for run in local_gz_files:
|
|
598
|
+
path = dataset_path / str(run)
|
|
599
|
+
try:
|
|
600
|
+
cls._preproc_run(path, local_gz_files, run)
|
|
601
|
+
except Exception:
|
|
602
|
+
shutil.rmtree(path)
|
|
603
|
+
raise
|
|
604
|
+
shutil.rmtree(tempdir / str(run))
|
|
605
|
+
torchrl_logger.info(f"Concluded run {run} out of {total_episodes}")
|
|
606
|
+
|
|
607
|
+
@classmethod
|
|
608
|
+
def _preproc_run(cls, path, gz_files, run):
|
|
609
|
+
files = gz_files[run]
|
|
610
|
+
td = TensorDict()
|
|
611
|
+
path = Path(path)
|
|
612
|
+
for file in files:
|
|
613
|
+
name = str(Path(file).parts[-1]).split(".")[0]
|
|
614
|
+
with gzip.GzipFile(file, mode="rb") as f:
|
|
615
|
+
file_content = f.read()
|
|
616
|
+
file_content = io.BytesIO(file_content)
|
|
617
|
+
file_content = np.load(file_content)
|
|
618
|
+
t = torch.as_tensor(file_content)
|
|
619
|
+
# Create the memmap file
|
|
620
|
+
key = cls._process_name(name)
|
|
621
|
+
if key == ("data", "observation"):
|
|
622
|
+
shape = t.shape
|
|
623
|
+
shape = [shape[0] + 1] + list(shape[1:])
|
|
624
|
+
filename = path / "data" / "observation.memmap"
|
|
625
|
+
os.makedirs(filename.parent, exist_ok=True)
|
|
626
|
+
mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename)
|
|
627
|
+
mmap[:-1].copy_(t)
|
|
628
|
+
td[key] = mmap
|
|
629
|
+
# td["data", "next", key[1:]] = mmap[1:]
|
|
630
|
+
else:
|
|
631
|
+
if key in (
|
|
632
|
+
("data", "reward"),
|
|
633
|
+
("data", "done"),
|
|
634
|
+
("data", "terminated"),
|
|
635
|
+
):
|
|
636
|
+
filename = path / "data" / "next" / (key[-1] + ".memmap")
|
|
637
|
+
os.makedirs(filename.parent, exist_ok=True)
|
|
638
|
+
mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
|
|
639
|
+
td["data", "next", key[1:]] = mmap
|
|
640
|
+
else:
|
|
641
|
+
filename = path
|
|
642
|
+
for i, _key in enumerate(key):
|
|
643
|
+
if i == len(key) - 1:
|
|
644
|
+
_key = _key + ".memmap"
|
|
645
|
+
filename = filename / _key
|
|
646
|
+
os.makedirs(filename.parent, exist_ok=True)
|
|
647
|
+
mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
|
|
648
|
+
td[key] = mmap
|
|
649
|
+
td.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1]))
|
|
650
|
+
td.memmap_(path, copy_existing=False)
|
|
651
|
+
|
|
652
|
+
@staticmethod
|
|
653
|
+
def _process_name(name):
|
|
654
|
+
if name.endswith("_ckpt"):
|
|
655
|
+
name = name[:-5]
|
|
656
|
+
if "store" in name:
|
|
657
|
+
key = ("data", name.split("_")[1])
|
|
658
|
+
else:
|
|
659
|
+
key = (name,)
|
|
660
|
+
if key[-1] == "terminal":
|
|
661
|
+
key = (*key[:-1], "terminated")
|
|
662
|
+
return key
|
|
663
|
+
|
|
664
|
+
@classmethod
|
|
665
|
+
def _list_runs(cls, download_path, gz_files=None) -> dict:
|
|
666
|
+
path = download_path
|
|
667
|
+
if gz_files is None:
|
|
668
|
+
gz_files = []
|
|
669
|
+
for root, _, files in os.walk(path):
|
|
670
|
+
for file in files:
|
|
671
|
+
if file.endswith(".gz"):
|
|
672
|
+
gz_files.append(os.path.join(root, file))
|
|
673
|
+
runs = defaultdict(list)
|
|
674
|
+
for file in gz_files:
|
|
675
|
+
filename = Path(file).parts[-1]
|
|
676
|
+
name, episode, extension = str(filename).split(".")
|
|
677
|
+
episode = int(episode)
|
|
678
|
+
runs[episode].append(file)
|
|
679
|
+
return dict(sorted(runs.items(), key=lambda x: x[0]))
|
|
680
|
+
|
|
681
|
+
def preprocess(
|
|
682
|
+
self,
|
|
683
|
+
fn: Callable[[TensorDictBase], TensorDictBase],
|
|
684
|
+
dim: int = 0,
|
|
685
|
+
num_workers: int | None = None,
|
|
686
|
+
*,
|
|
687
|
+
chunksize: int | None = None,
|
|
688
|
+
num_chunks: int | None = None,
|
|
689
|
+
pool: mp.Pool | None = None,
|
|
690
|
+
generator: torch.Generator | None = None,
|
|
691
|
+
max_tasks_per_child: int | None = None,
|
|
692
|
+
worker_threads: int = 1,
|
|
693
|
+
index_with_generator: bool = False,
|
|
694
|
+
pbar: bool = False,
|
|
695
|
+
mp_start_method: str | None = None,
|
|
696
|
+
dest: str | Path,
|
|
697
|
+
num_frames: int | None = None,
|
|
698
|
+
):
|
|
699
|
+
# Copy data to a tensordict
|
|
700
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
|
701
|
+
first_item = self[0]
|
|
702
|
+
metadata = first_item.pop("metadata")
|
|
703
|
+
|
|
704
|
+
mmap = fn(first_item)
|
|
705
|
+
if num_frames is None:
|
|
706
|
+
num_frames = len(self)
|
|
707
|
+
mmap = mmap.expand(num_frames, *first_item.shape)
|
|
708
|
+
mmap = mmap.memmap_like(tmpdir, num_threads=32)
|
|
709
|
+
with mmap.unlock_():
|
|
710
|
+
mmap["_indices"] = torch.arange(mmap.shape[0])
|
|
711
|
+
mmap.memmap_(tmpdir, num_threads=32)
|
|
712
|
+
|
|
713
|
+
def func(mmap: TensorDictBase):
|
|
714
|
+
idx = mmap["_indices"]
|
|
715
|
+
orig = self[idx].exclude("metadata")
|
|
716
|
+
orig = fn(orig)
|
|
717
|
+
mmap.update(orig, inplace=True)
|
|
718
|
+
return
|
|
719
|
+
|
|
720
|
+
if dim != 0:
|
|
721
|
+
raise RuntimeError("dim != 0 is not supported.")
|
|
722
|
+
|
|
723
|
+
mmap.map(
|
|
724
|
+
fn=CloudpickleWrapper(func),
|
|
725
|
+
dim=dim,
|
|
726
|
+
num_workers=num_workers,
|
|
727
|
+
chunksize=chunksize,
|
|
728
|
+
num_chunks=num_chunks,
|
|
729
|
+
pool=pool,
|
|
730
|
+
generator=generator,
|
|
731
|
+
max_tasks_per_child=max_tasks_per_child,
|
|
732
|
+
worker_threads=worker_threads,
|
|
733
|
+
index_with_generator=index_with_generator,
|
|
734
|
+
mp_start_method=mp_start_method,
|
|
735
|
+
pbar=pbar,
|
|
736
|
+
)
|
|
737
|
+
|
|
738
|
+
with mmap.unlock_():
|
|
739
|
+
return TensorStorage(mmap.set("metadata", metadata))
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
class _AtariStorage(Storage):
|
|
743
|
+
def __init__(self, path):
|
|
744
|
+
self.path = Path(path)
|
|
745
|
+
|
|
746
|
+
def get_folders(path):
|
|
747
|
+
return [
|
|
748
|
+
name
|
|
749
|
+
for name in os.listdir(path)
|
|
750
|
+
if os.path.isdir(os.path.join(path, name))
|
|
751
|
+
]
|
|
752
|
+
|
|
753
|
+
# Collect the integer-named split sub-folders found under `path`
|
|
754
|
+
self.splits = []
|
|
755
|
+
folders = get_folders(path)
|
|
756
|
+
for folder in folders:
|
|
757
|
+
self.splits.append(int(Path(folder).parts[-1]))
|
|
758
|
+
self.splits = sorted(self.splits)
|
|
759
|
+
self._split_tds = []
|
|
760
|
+
frames_per_split = {}
|
|
761
|
+
for split in self.splits:
|
|
762
|
+
path = self.path / str(split)
|
|
763
|
+
self._split_tds.append(self._load_split(path))
|
|
764
|
+
# take away 1 because we padded with 1 empty val
|
|
765
|
+
frames_per_split[split] = (
|
|
766
|
+
self._split_tds[-1].get(("data", "observation")).shape[0] - 1
|
|
767
|
+
)
|
|
768
|
+
|
|
769
|
+
frames_per_split = torch.tensor(
|
|
770
|
+
[[split, length] for (split, length) in frames_per_split.items()]
|
|
771
|
+
)
|
|
772
|
+
frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
|
|
773
|
+
self.frames_per_split = torch.cat(
|
|
774
|
+
# [torch.tensor([[-1, 0]]), frames_per_split], 0
|
|
775
|
+
[torch.tensor([[-1, 0]]), frames_per_split],
|
|
776
|
+
0,
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
# retrieve episodes
|
|
780
|
+
self.episodes = torch.cumsum(
|
|
781
|
+
torch.cat(
|
|
782
|
+
[td.get(("data", "next", "terminated")) for td in self._split_tds], 0
|
|
783
|
+
),
|
|
784
|
+
0,
|
|
785
|
+
)
|
|
786
|
+
super().__init__(max_size=len(self))
|
|
787
|
+
|
|
788
|
+
def __len__(self):
|
|
789
|
+
return self.frames_per_split[-1, 1].item()
|
|
790
|
+
|
|
791
|
+
def _read_from_splits(self, item: int | torch.Tensor):
|
|
792
|
+
# We need to allocate each item to its storage.
|
|
793
|
+
# We don't assume each storage has the same size (too expensive to test)
|
|
794
|
+
# so we keep a map of each storage cumulative length and retrieve the
|
|
795
|
+
# storages one after the other.
|
|
796
|
+
item = torch.as_tensor(item)
|
|
797
|
+
if not item.ndim:
|
|
798
|
+
is_int = True
|
|
799
|
+
item = item.reshape(-1)
|
|
800
|
+
else:
|
|
801
|
+
is_int = False
|
|
802
|
+
split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
|
|
803
|
+
item >= self.frames_per_split[:-1, 1].unsqueeze(1)
|
|
804
|
+
)
|
|
805
|
+
# split_tmp, idx = split.squeeze().nonzero().unbind(-1)
|
|
806
|
+
split_tmp, idx = split.nonzero().unbind(-1)
|
|
807
|
+
split = split_tmp.squeeze()
|
|
808
|
+
idx = idx.squeeze()
|
|
809
|
+
|
|
810
|
+
if not is_int:
|
|
811
|
+
split = torch.zeros_like(split_tmp)
|
|
812
|
+
split[idx] = split_tmp
|
|
813
|
+
split = self.frames_per_split[split + 1, 0]
|
|
814
|
+
item = item - self.frames_per_split[split, 1]
|
|
815
|
+
if is_int:
|
|
816
|
+
item = item.squeeze()
|
|
817
|
+
return self._proc_td(self._split_tds[split], item)
|
|
818
|
+
unique_splits, split_inverse = torch.unique(split, return_inverse=True)
|
|
819
|
+
unique_splits = unique_splits.tolist()
|
|
820
|
+
out = []
|
|
821
|
+
for i, split in enumerate(unique_splits):
|
|
822
|
+
_item = item[split_inverse == i] if split_inverse is not None else item
|
|
823
|
+
out.append(self._proc_td(self._split_tds[split], _item))
|
|
824
|
+
return torch.cat(out, 0)
|
|
825
|
+
|
|
826
|
+
def _load_split(self, path):
|
|
827
|
+
return TensorDict.load_memmap(path)
|
|
828
|
+
|
|
829
|
+
def _proc_td(self, td, index):
|
|
830
|
+
td_data = td.get("data")
|
|
831
|
+
obs_ = td_data.get("observation")[index + 1]
|
|
832
|
+
done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool()
|
|
833
|
+
if done.ndim and done.any():
|
|
834
|
+
obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0)
|
|
835
|
+
td_idx = td.empty()
|
|
836
|
+
td_idx.set(("next", "observation"), obs_)
|
|
837
|
+
non_tensor = td.exclude("data").to_dict()
|
|
838
|
+
td_idx.update(td_data.apply(lambda x: x[index]))
|
|
839
|
+
if isinstance(index, torch.Tensor) and index.ndim:
|
|
840
|
+
td_idx.batch_size = [len(index)]
|
|
841
|
+
td_idx.set_non_tensor("metadata", non_tensor)
|
|
842
|
+
|
|
843
|
+
terminated = td_idx.get(("next", "terminated"))
|
|
844
|
+
zterminated = torch.zeros_like(terminated)
|
|
845
|
+
td_idx.set(("next", "done"), terminated.clone())
|
|
846
|
+
td_idx.set(("next", "truncated"), zterminated)
|
|
847
|
+
td_idx.set("terminated", zterminated)
|
|
848
|
+
td_idx.set("done", zterminated)
|
|
849
|
+
td_idx.set("truncated", zterminated)
|
|
850
|
+
|
|
851
|
+
return td_idx
|
|
852
|
+
|
|
853
|
+
def get(self, index):
|
|
854
|
+
if isinstance(index, int):
|
|
855
|
+
return self._read_from_splits(index)
|
|
856
|
+
if isinstance(index, tuple):
|
|
857
|
+
if len(index) == 1:
|
|
858
|
+
return self.get(index[0])
|
|
859
|
+
return self.get(index[0])[(Ellipsis, *index[1:])]
|
|
860
|
+
if isinstance(index, torch.Tensor):
|
|
861
|
+
if index.ndim <= 1:
|
|
862
|
+
return self._read_from_splits(index)
|
|
863
|
+
elif index.shape[1] == 1:
|
|
864
|
+
index = index.squeeze(1)
|
|
865
|
+
return self.get(index)
|
|
866
|
+
else:
|
|
867
|
+
raise RuntimeError("Only 1d tensors are accepted")
|
|
868
|
+
# with ThreadPoolExecutor(16) as pool:
|
|
869
|
+
# results = map(self.__getitem__, index.tolist())
|
|
870
|
+
# return torch.stack(list(results))
|
|
871
|
+
if isinstance(index, (range, list)):
|
|
872
|
+
return self[torch.tensor(index)]
|
|
873
|
+
if isinstance(index, slice):
|
|
874
|
+
start = index.start if index.start is not None else 0
|
|
875
|
+
stop = index.stop if index.stop is not None else len(self)
|
|
876
|
+
step = index.step if index.step is not None else 1
|
|
877
|
+
return self.get(torch.arange(start, stop, step))
|
|
878
|
+
return self[torch.arange(len(self))[index]]
|