torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from functools import partial
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from omegaconf import MISSING
|
|
13
|
+
|
|
14
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
15
|
+
from torchrl.trainers.algorithms.configs.envs import EnvConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class BaseCollectorConfig(ConfigBase):
    """Parent class to configure a data collector.

    Marker base for the concrete collector configs below
    (sync, async, multi-sync, multi-async); it adds no fields of its own.
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class CollectorConfig(BaseCollectorConfig):
    """A class to configure a synchronous data collector (Collector)."""

    # Env factory config; flipped to partial in __post_init__ so Hydra
    # instantiation yields a callable factory rather than a built env.
    create_env_fn: ConfigBase = MISSING
    # Policy (or a factory producing one; the factory is also made partial).
    policy: Any = None
    policy_factory: Any = None
    # Collection budget: frames per yielded batch; total_frames=-1
    # presumably means "collect until stopped" — confirm in Collector docs.
    frames_per_batch: int | None = None
    total_frames: int = -1
    # Frames gathered before the policy is used (random/exploratory phase).
    init_random_frames: int | None = 0
    # Device placement: a global default plus per-role overrides.
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # Extra kwargs forwarded to the env factory.
    create_env_kwargs: dict | None = None
    # Trajectory handling options.
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: Any = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    return_same_td: bool = False
    interruptor: Any = None
    set_truncated: bool = False
    # Replay-buffer integration.
    use_buffers: bool = False
    replay_buffer: Any = None
    extend_buffer: bool = False
    # Policy wrapping / compilation options.
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # Weight synchronization between trainer and collector.
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # Hydra instantiation target and partial flag for this config itself.
    _target_: str = "torchrl.collectors.Collector"
    _partial_: bool = False

    def __post_init__(self):
        # Factory-like sub-configs must be instantiated as partials so the
        # collector can invoke them itself (possibly in worker processes).
        self.create_env_fn._partial_ = True
        if self.policy_factory is not None:
            self.policy_factory._partial_ = True
        if self.weight_updater is not None:
            self.weight_updater._partial_ = True
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Legacy alias: kept for backward compatibility with the older
# ``SyncDataCollector`` naming.
SyncDataCollectorConfig = CollectorConfig
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class AsyncCollectorConfig(BaseCollectorConfig):
    """Configuration for asynchronous data collector (AsyncCollector)."""

    # Env factory config; defaults to a partial EnvConfig and is forced
    # partial again in __post_init__.
    create_env_fn: ConfigBase = field(
        default_factory=partial(EnvConfig, _partial_=True)
    )
    # Policy (or a factory producing one; the factory is made partial).
    policy: Any = None
    policy_factory: Any = None
    # Collection budget; total_frames=-1 presumably means "until stopped"
    # — confirm in AsyncCollector docs.
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # Device placement: a global default plus per-role overrides.
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # Extra kwargs forwarded to the env factory.
    create_env_kwargs: dict | None = None
    # Trajectory handling options.
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    # Replay-buffer integration.
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # Policy wrapping / compilation options.
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # Weight synchronization between trainer and collector.
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # Hydra instantiation target and partial flag for this config itself.
    _target_: str = "torchrl.collectors.AsyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # Factory-like sub-configs must be instantiated as partials so the
        # collector can invoke them itself (possibly in worker processes).
        self.create_env_fn._partial_ = True
        if self.policy_factory is not None:
            self.policy_factory._partial_ = True
        if self.weight_updater is not None:
            self.weight_updater._partial_ = True
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Legacy alias: kept for backward compatibility with the older
# ``aSyncDataCollector``/async collector naming.
AsyncDataCollectorConfig = AsyncCollectorConfig
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass
class MultiSyncCollectorConfig(BaseCollectorConfig):
    """Configuration for multi-synchronous data collector (MultiSyncCollector)."""

    # Expected to be an iterable of env factory configs, one per worker
    # (it is iterated in __post_init__).
    create_env_fn: Any = MISSING
    num_workers: int | None = None
    # Policy (or a factory producing one; the factory is made partial).
    policy: Any = None
    policy_factory: Any = None
    # Collection budget; total_frames=-1 presumably means "until stopped"
    # — confirm in collector docs.
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # Device placement: a global default plus per-role overrides.
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # Extra kwargs forwarded to the env factories.
    create_env_kwargs: dict | None = None
    # Trajectory handling options.
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    # Replay-buffer integration.
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # Policy wrapping / compilation options.
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # Weight synchronization between trainer and collectors.
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # Hydra instantiation target and partial flag for this config itself.
    _target_: str = "torchrl.collectors.MultiSyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # Each per-worker env config must instantiate to a callable factory.
        for env_cfg in self.create_env_fn:
            env_cfg._partial_ = True
        if self.policy_factory is not None:
            self.policy_factory._partial_ = True
        if self.weight_updater is not None:
            self.weight_updater._partial_ = True
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Legacy alias: the original line was a no-op self-assignment
# (``MultiSyncCollectorConfig = MultiSyncCollectorConfig``), so the legacy
# name was never actually exported. Following the pattern of the sibling
# aliases (SyncDataCollectorConfig, AsyncDataCollectorConfig), expose the
# config under the pre-rename ``MultiSyncDataCollector`` naming.
MultiSyncDataCollectorConfig = MultiSyncCollectorConfig
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclass
class MultiAsyncCollectorConfig(BaseCollectorConfig):
    """Configuration for multi-asynchronous data collector (MultiAsyncCollector)."""

    # Expected to be an iterable of env factory configs, one per worker
    # (it is iterated in __post_init__).
    create_env_fn: Any = MISSING
    num_workers: int | None = None
    # Policy (or a factory producing one; the factory is made partial).
    policy: Any = None
    policy_factory: Any = None
    # Collection budget; total_frames=-1 presumably means "until stopped"
    # — confirm in collector docs.
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # Device placement: a global default plus per-role overrides.
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # Extra kwargs forwarded to the env factories.
    create_env_kwargs: dict | None = None
    # Trajectory handling options.
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    # Replay-buffer integration.
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # Policy wrapping / compilation options.
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # Weight synchronization between trainer and collectors.
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # Hydra instantiation target and partial flag for this config itself.
    _target_: str = "torchrl.collectors.MultiAsyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # Each per-worker env config must instantiate to a callable factory.
        for env_cfg in self.create_env_fn:
            env_cfg._partial_ = True
        if self.policy_factory is not None:
            self.policy_factory._partial_ = True
        if self.weight_updater is not None:
            self.weight_updater._partial_ = True
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# Legacy alias: the original line was a no-op self-assignment
# (``MultiAsyncCollectorConfig = MultiAsyncCollectorConfig``), so the legacy
# name was never actually exported. Expose the config under the historic
# ``MultiaSyncDataCollector`` naming, mirroring the sibling aliases.
MultiaSyncDataCollectorConfig = MultiAsyncCollectorConfig
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ConfigBase(ABC):
    """Abstract base class for all configuration classes.

    This class serves as the foundation for all configuration classes in the
    configurable configuration system, providing a common interface and structure.

    Subclasses are dataclasses and must implement ``__post_init__``; the hook
    is where factory-style sub-configs are marked partial, defaults validated,
    etc.
    """

    @abstractmethod
    def __post_init__(self) -> None:
        """Post-initialization hook for configuration classes."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class Config:
    """A flexible config that allows arbitrary fields.

    All attribute reads/writes are delegated to an internal omegaconf
    :class:`DictConfig`, so values behave like any other omegaconf node
    (interpolation, dot access, etc.).
    """

    def __init__(self, **kwargs):
        # Backing store for every attribute except ``_config`` itself.
        self._config = DictConfig(kwargs)

    def __getattr__(self, name):
        # Guard against infinite recursion: if ``_config`` itself is missing
        # (e.g. during unpickling or copy.copy, which bypass __init__),
        # looking it up would re-enter __getattr__ forever. Raise the
        # conventional AttributeError instead.
        if name == "_config":
            raise AttributeError(name)
        return getattr(self._config, name)

    def __setattr__(self, name, value):
        # ``_config`` is stored on the instance directly; everything else is
        # routed into the DictConfig.
        if name == "_config":
            super().__setattr__(name, value)
        else:
            setattr(self._config, name, value)
|
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from omegaconf import MISSING
|
|
12
|
+
|
|
13
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class WriterConfig(ConfigBase):
    """Base configuration class for replay buffer writers."""

    # Hydra instantiation target: the writer class to build.
    _target_: str = "torchrl.data.replay_buffers.Writer"

    def __post_init__(self) -> None:
        """Post-initialization hook for writer configurations."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class RoundRobinWriterConfig(WriterConfig):
    """Configuration for round-robin writer that distributes data across multiple storages."""

    _target_: str = "torchrl.data.replay_buffers.RoundRobinWriter"
    # Whether the writer should be torch.compile-friendly — confirm against
    # RoundRobinWriter's ``compilable`` argument.
    compilable: bool = False

    def __post_init__(self) -> None:
        """Post-initialization hook for round-robin writer configurations."""
        super().__post_init__()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class SamplerConfig(ConfigBase):
    """Base configuration class for replay buffer samplers."""

    # Hydra instantiation target: the sampler class to build.
    _target_: str = "torchrl.data.replay_buffers.Sampler"

    def __post_init__(self) -> None:
        """Post-initialization hook for sampler configurations."""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class RandomSamplerConfig(SamplerConfig):
    """Configuration for random sampling from replay buffer."""

    _target_: str = "torchrl.data.replay_buffers.RandomSampler"

    def __post_init__(self) -> None:
        """Post-initialization hook for random sampler configurations."""
        super().__post_init__()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class WriterEnsembleConfig(WriterConfig):
    """Configuration for ensemble writer that combines multiple writers."""

    _target_: str = "torchrl.data.replay_buffers.WriterEnsemble"
    # Sub-writer configs, one per storage in the ensemble.
    writers: list[Any] = field(default_factory=list)
    # Presumably per-writer selection probabilities — confirm against
    # WriterEnsemble's ``p`` argument.
    p: Any = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class TensorDictMaxValueWriterConfig(WriterConfig):
    """Configuration for TensorDict max value writer."""

    _target_: str = "torchrl.data.replay_buffers.TensorDictMaxValueWriter"
    # Key used to rank incoming samples; ``reduction`` folds multi-element
    # ranks into a scalar (default "sum").
    rank_key: Any = None
    reduction: str = "sum"
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class TensorDictRoundRobinWriterConfig(WriterConfig):
    """Configuration for TensorDict round-robin writer."""

    _target_: str = "torchrl.data.replay_buffers.TensorDictRoundRobinWriter"
    # Whether the writer should be torch.compile-friendly — confirm against
    # the writer's ``compilable`` argument.
    compilable: bool = False
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class ImmutableDatasetWriterConfig(WriterConfig):
    """Configuration for immutable dataset writer."""

    _target_: str = "torchrl.data.replay_buffers.ImmutableDatasetWriter"
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class SamplerEnsembleConfig(SamplerConfig):
    """Configuration for ensemble sampler that combines multiple samplers.

    See ``torchrl.data.replay_buffers.SamplerEnsemble`` for the semantics of
    the constructor arguments mirrored by these fields.
    """

    _target_: str = "torchrl.data.replay_buffers.SamplerEnsemble"
    # Component sampler configs/instances forwarded to the ensemble.
    samplers: list[Any] = field(default_factory=list)
    # Selection weights/probabilities — semantics defined by SamplerEnsemble.
    p: Any = None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class PrioritizedSliceSamplerConfig(SamplerConfig):
    """Configuration for prioritized slice sampling from replay buffer.

    Combines the slice-sampling fields (shared with ``SliceSamplerConfig``)
    with the prioritized-sampling fields (shared with
    ``PrioritizedSamplerConfig``). See
    ``torchrl.data.replay_buffers.PrioritizedSliceSampler`` for the exact
    parameter semantics.
    """

    # --- Slice-sampling parameters (same set as SliceSamplerConfig) ---
    num_slices: int | None = None
    slice_len: int | None = None
    end_key: Any = None
    traj_key: Any = None
    ends: Any = None
    trajectories: Any = None
    cache_values: bool = False
    truncated_key: Any = ("next", "truncated")
    strict_length: bool = True
    # NOTE: field name mirrors the sampler's keyword and shadows the builtin.
    compile: Any = False
    span: Any = False
    use_gpu: Any = False
    # --- Prioritized-sampling parameters (same set as PrioritizedSamplerConfig) ---
    max_capacity: int | None = None
    alpha: float | None = None
    beta: float | None = None
    eps: float | None = None
    reduction: str | None = None
    _target_: str = "torchrl.data.replay_buffers.PrioritizedSliceSampler"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
class SliceSamplerWithoutReplacementConfig(SamplerConfig):
    """Configuration for slice sampling without replacement.

    Field set matches ``SliceSamplerConfig``; see
    ``torchrl.data.replay_buffers.SliceSamplerWithoutReplacement`` for the
    exact parameter semantics.
    """

    _target_: str = "torchrl.data.replay_buffers.SliceSamplerWithoutReplacement"
    num_slices: int | None = None
    slice_len: int | None = None
    end_key: Any = None
    traj_key: Any = None
    ends: Any = None
    trajectories: Any = None
    cache_values: bool = False
    truncated_key: Any = ("next", "truncated")
    strict_length: bool = True
    # NOTE: field name mirrors the sampler's keyword and shadows the builtin.
    compile: Any = False
    span: Any = False
    use_gpu: Any = False
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@dataclass
class SliceSamplerConfig(SamplerConfig):
    """Configuration for slice sampling from replay buffer.

    Mirrors the constructor of ``torchrl.data.replay_buffers.SliceSampler``;
    see that class for the exact parameter semantics.
    """

    _target_: str = "torchrl.data.replay_buffers.SliceSampler"
    num_slices: int | None = None
    slice_len: int | None = None
    end_key: Any = None
    traj_key: Any = None
    ends: Any = None
    trajectories: Any = None
    cache_values: bool = False
    truncated_key: Any = ("next", "truncated")
    strict_length: bool = True
    # NOTE: field name mirrors the sampler's keyword and shadows the builtin.
    compile: Any = False
    span: Any = False
    use_gpu: Any = False
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@dataclass
class PrioritizedSamplerConfig(SamplerConfig):
    """Configuration for prioritized sampling from replay buffer.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.PrioritizedSampler``. ``None`` defaults
    defer to the sampler's own defaults.
    """

    max_capacity: int | None = None
    alpha: float | None = None
    beta: float | None = None
    eps: float | None = None
    reduction: str | None = None
    _target_: str = "torchrl.data.replay_buffers.PrioritizedSampler"
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@dataclass
class SamplerWithoutReplacementConfig(SamplerConfig):
    """Configuration for sampling without replacement.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.SamplerWithoutReplacement``.
    """

    _target_: str = "torchrl.data.replay_buffers.SamplerWithoutReplacement"
    # Whether to drop the last incomplete batch of an epoch.
    drop_last: bool = False
    # Whether to shuffle items before iterating over them.
    shuffle: bool = True
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@dataclass
class StorageConfig(ConfigBase):
    """Base configuration class for replay buffer storage.

    Concrete storage configurations subclass this and override ``_target_``
    with the dotted path of the storage class to instantiate.
    """

    # When True, instantiation yields a partial/factory instead of an instance.
    _partial_: bool = False
    _target_: str = "torchrl.data.replay_buffers.Storage"

    def __post_init__(self) -> None:
        """Post-initialization hook for storage configurations; no-op by default."""
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@dataclass
class TensorStorageConfig(StorageConfig):
    """Configuration for tensor-based storage in replay buffer.

    Mirrors the constructor of ``torchrl.data.replay_buffers.TensorStorage``.
    """

    _target_: str = "torchrl.data.replay_buffers.TensorStorage"
    # Maximum number of elements; None defers to the storage default.
    max_size: int | None = None
    # Pre-allocated tensor/tensordict backing the storage.
    storage: Any = None
    # Device on which the storage lives.
    device: Any = None
    # Number of leading batch dimensions of the storage.
    ndim: int | None = None
    # Whether the storage should be compatible with torch.compile.
    compilable: bool = False

    def __post_init__(self) -> None:
        """Post-initialization hook for tensor storage configurations."""
        super().__post_init__()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@dataclass
class ListStorageConfig(StorageConfig):
    """Configuration for list-based storage in replay buffer.

    Mirrors the constructor of ``torchrl.data.replay_buffers.ListStorage``.
    """

    _target_: str = "torchrl.data.replay_buffers.ListStorage"
    # Maximum number of elements; None defers to the storage default.
    max_size: int | None = None
    # Whether the storage should be compatible with torch.compile.
    compilable: bool = False
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@dataclass
class StorageEnsembleWriterConfig(StorageConfig):
    """Configuration for storage ensemble writer.

    NOTE(review): this config inherits from ``StorageConfig`` although its
    target is a writer — confirm whether it should derive from
    ``WriterConfig`` instead.
    """

    _target_: str = "torchrl.data.replay_buffers.StorageEnsembleWriter"
    # Required: component writer configs/instances.
    writers: list[Any] = MISSING
    # Required: transforms associated with each writer.
    transforms: list[Any] = MISSING
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@dataclass
class LazyStackStorageConfig(StorageConfig):
    """Configuration for lazy stack storage.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.LazyStackStorage``.
    """

    _target_: str = "torchrl.data.replay_buffers.LazyStackStorage"
    # Maximum number of elements; None defers to the storage default.
    max_size: int | None = None
    # Whether the storage should be compatible with torch.compile.
    compilable: bool = False
    # Dimension along which stored items are stacked.
    stack_dim: int = 0
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@dataclass
class StorageEnsembleConfig(StorageConfig):
    """Configuration for storage ensemble.

    See ``torchrl.data.replay_buffers.StorageEnsemble`` for the semantics of
    the constructor arguments mirrored by these fields.
    """

    _target_: str = "torchrl.data.replay_buffers.StorageEnsemble"
    # Required: component storage configs/instances.
    storages: list[Any] = MISSING
    # Required: transforms associated with each storage.
    transforms: list[Any] = MISSING
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@dataclass
class LazyMemmapStorageConfig(StorageConfig):
    """Configuration for lazy memory-mapped storage.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.LazyMemmapStorage``. Field set parallels
    ``LazyTensorStorageConfig``.
    """

    _target_: str = "torchrl.data.replay_buffers.LazyMemmapStorage"
    # Maximum number of elements; None defers to the storage default.
    max_size: int | None = None
    # Device on which the storage lives.
    device: Any = None
    # Number of leading batch dimensions of the storage.
    ndim: int = 1
    # Whether the storage should be compatible with torch.compile.
    compilable: bool = False
    # Whether initialization should be shared across processes.
    shared_init: bool = False
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@dataclass
class LazyTensorStorageConfig(StorageConfig):
    """Configuration for lazy tensor storage.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.LazyTensorStorage``. Field set parallels
    ``LazyMemmapStorageConfig``.
    """

    _target_: str = "torchrl.data.replay_buffers.LazyTensorStorage"
    # Maximum number of elements; None defers to the storage default.
    max_size: int | None = None
    # Device on which the storage lives.
    device: Any = None
    # Number of leading batch dimensions of the storage.
    ndim: int = 1
    # Whether the storage should be compatible with torch.compile.
    compilable: bool = False
    # Whether initialization should be shared across processes.
    shared_init: bool = False
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@dataclass
class ReplayBufferBaseConfig(ConfigBase):
    """Base configuration class for replay buffers.

    Concrete buffer configurations subclass this and supply a ``_target_``
    pointing at the buffer class to instantiate.
    """

    # When True, instantiation yields a partial/factory instead of an instance.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Post-initialization hook for replay buffer configurations; no-op by default."""
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
class TensorDictReplayBufferConfig(ReplayBufferBaseConfig):
    """Configuration for TensorDict-based replay buffer.

    Mirrors the constructor of
    ``torchrl.data.replay_buffers.TensorDictReplayBuffer``. ``None`` values
    defer to the buffer's own defaults.
    """

    _target_: str = "torchrl.data.replay_buffers.TensorDictReplayBuffer"
    # Sampler config/instance.
    sampler: Any = None
    # Storage config/instance.
    storage: Any = None
    # Writer config/instance.
    writer: Any = None
    # Transform applied to sampled data.
    transform: Any = None
    # Default batch size used when sampling.
    batch_size: int | None = None

    def __post_init__(self) -> None:
        """Post-initialization hook for TensorDict replay buffer configurations."""
        super().__post_init__()
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@dataclass
class ReplayBufferConfig(ReplayBufferBaseConfig):
    """Configuration for generic replay buffer.

    Mirrors the constructor of ``torchrl.data.replay_buffers.ReplayBuffer``.
    ``None`` values defer to the buffer's own defaults.
    """

    _target_: str = "torchrl.data.replay_buffers.ReplayBuffer"
    # Sampler config/instance.
    sampler: Any = None
    # Storage config/instance.
    storage: Any = None
    # Writer config/instance.
    writer: Any = None
    # Transform applied to sampled data.
    transform: Any = None
    # Default batch size used when sampling.
    batch_size: int | None = None
    # Whether the buffer is shared across processes.
    shared: bool = False
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from omegaconf import MISSING
|
|
12
|
+
|
|
13
|
+
from torchrl.envs.common import EnvBase
|
|
14
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EnvConfig(ConfigBase):
    """Base configuration class for environments."""

    # When True, instantiation yields a partial/factory instead of an env.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Post-initialization hook for environment configurations.

        Note: unconditionally resets ``_partial_`` to ``False``, overriding
        any value passed at construction time.
        """
        self._partial_ = False
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class BatchedEnvConfig(EnvConfig):
    """Configuration for batched environments.

    Instantiated through
    ``torchrl.trainers.algorithms.configs.envs.make_batched_env`` rather
    than a torchrl class directly.
    """

    # Required: env factory (config or callable) used to build each worker env.
    create_env_fn: Any = MISSING
    # Number of worker environments.
    num_workers: int = 1
    # Extra keyword arguments forwarded to each env factory call.
    create_env_kwargs: dict = field(default_factory=dict)
    # One of "parallel", "serial" or "async" (see make_batched_env).
    batched_env_type: str = "parallel"
    # Device to place the batched environment on.
    device: str | None = None
    # batched_env_type: Literal["parallel", "serial", "async"] = "parallel"
    _target_: str = "torchrl.trainers.algorithms.configs.envs.make_batched_env"

    def __post_init__(self) -> None:
        """Post-initialization hook for batched environment configurations.

        Forces the nested env config to be instantiated as a partial so that
        ``make_batched_env`` receives a factory rather than an env instance.
        """
        super().__post_init__()
        if hasattr(self.create_env_fn, "_partial_"):
            self.create_env_fn._partial_ = True
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class TransformedEnvConfig(EnvConfig):
    """Configuration for transformed environments.

    Mirrors the constructor of ``torchrl.envs.TransformedEnv``.
    """

    # Required: base env config/instance to wrap.
    base_env: Any = MISSING
    # Transform (or transform config) applied on top of the base env.
    transform: Any = None
    # Whether env specs are cached after first computation.
    cache_specs: bool = True
    # Whether nested TransformedEnvs are unwrapped; None defers to the default.
    auto_unwrap: bool | None = None
    _target_: str = "torchrl.envs.TransformedEnv"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def make_batched_env(
    create_env_fn, num_workers, batched_env_type="parallel", device=None, **kwargs
):
    """Create a batched environment.

    Args:
        create_env_fn: Function to create individual environments, or an
            already-instantiated environment (``EnvBase``), in which case it
            is wrapped in a closure and the *same* instance is handed to the
            batched env constructor.
        num_workers: Number of worker environments.
        batched_env_type: Type of batched environment: ``"parallel"``,
            ``"serial"`` or ``"async"``.
        device: Device to place the batched environment on.
        **kwargs: Additional keyword arguments forwarded to the batched env
            constructor.

    Returns:
        The created batched environment instance.

    Raises:
        ValueError: If ``create_env_fn`` or ``num_workers`` is ``None``, or
            if ``batched_env_type`` is unknown.
        TypeError: If ``create_env_fn`` is neither an env instance nor a
            callable.
    """
    # Validate arguments before importing, so bad configs fail fast with a
    # clear error regardless of import state.
    if create_env_fn is None:
        raise ValueError("create_env_fn must be provided")

    if num_workers is None:
        raise ValueError("num_workers must be provided")

    from torchrl.envs import AsyncEnvPool, ParallelEnv, SerialEnv

    # If create_env_fn is already an env instance, wrap it so the batched
    # env constructors receive a factory.
    if isinstance(create_env_fn, EnvBase):
        # NOTE(review): every worker will receive the *same* instance —
        # confirm this is intended for num_workers > 1.
        env_instance = create_env_fn

        def env_fn(env_instance=env_instance):
            return env_instance

    else:
        env_fn = create_env_fn
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a non-callable slip through.
        if not callable(env_fn):
            raise TypeError(
                f"create_env_fn must be an env instance or a callable, got {env_fn}"
            )

    # Forward the device only when explicitly given, so the constructors'
    # own defaults apply otherwise.
    if device is not None:
        kwargs["device"] = device

    if batched_env_type == "parallel":
        return ParallelEnv(num_workers, env_fn, **kwargs)
    elif batched_env_type == "serial":
        return SerialEnv(num_workers, env_fn, **kwargs)
    elif batched_env_type == "async":
        # AsyncEnvPool expects one factory per worker.
        return AsyncEnvPool([env_fn] * num_workers, **kwargs)
    else:
        raise ValueError(f"Unknown batched_env_type: {batched_env_type}")
|