torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/dqn/utils_cartpole.py
@@ -0,0 +1,90 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch.nn
import torch.optim
from torchrl.data import Composite
from torchrl.envs import RewardSum, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, QValueActor
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
    env = TransformedEnv(env)
    env.append_transform(RewardSum())
    env.append_transform(StepCounter())
    return env


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_dqn_modules(proof_environment, device):

    # Define input shape
    input_shape = proof_environment.observation_spec["observation"].shape
    env_specs = proof_environment.specs
    num_outputs = env_specs["input_spec", "full_action_spec", "action"].space.n
    action_spec = env_specs["input_spec", "full_action_spec", "action"]

    # Define Q-Value Module
    mlp = MLP(
        in_features=input_shape[-1],
        activation_class=torch.nn.ReLU,
        out_features=num_outputs,
        num_cells=[120, 84],
        device=device,
    )

    qvalue_module = QValueActor(
        module=mlp,
        spec=Composite(action=action_spec).to(device),
        in_keys=["observation"],
    )
    return qvalue_module


def make_dqn_model(env_name, device):
    proof_environment = make_env(env_name, device=device)
    qvalue_module = make_dqn_modules(proof_environment, device=device)
    del proof_environment
    return qvalue_module


# ====================================================================
# Evaluation utils
# --------------------------------------------------------------------


def eval_model(actor, test_env, num_episodes=3):
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
    for i in range(num_episodes):
        td_test = test_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=10_000_000,
        )
        test_env.apply(dump_video)
        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
        test_rewards[i] = reward.sum()
        del td_test
    return test_rewards.mean()


def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()
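For orientation, here is a minimal usage sketch of the helpers in `utils_cartpole.py` above (illustrative only, not part of the wheel contents): it assumes the snippet runs next to `utils_cartpole.py` so the local import resolves, builds the greedy Q-value actor, and evaluates it on a fresh CartPole environment.

```python
# Illustrative sketch: build and evaluate the CartPole DQN actor defined above.
import torch

from utils_cartpole import eval_model, make_dqn_model, make_env  # assumes same directory

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Build the Q-value actor from a throwaway "proof" environment.
actor = make_dqn_model("CartPole-v1", device=device)

# Roll out a few evaluation episodes; eval_model returns the mean episode reward.
test_env = make_env("CartPole-v1", device=device)
mean_reward = eval_model(actor, test_env, num_episodes=3)
print(f"mean episode reward over 3 episodes: {mean_reward.item():.1f}")
```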
sota-implementations/dreamer/README.md
@@ -0,0 +1,129 @@

# Dreamer V1

This is an implementation of the Dreamer algorithm from the paper
["Dream to Control: Learning Behaviors by Latent Imagination"](https://arxiv.org/abs/1912.01603) (Hafner et al., ICLR 2020).

Dreamer is a model-based reinforcement learning algorithm that:
1. Learns a **world model** (RSSM) from experience
2. **Imagines** future trajectories in latent space
3. Trains **actor and critic** using analytic gradients through the imagined rollouts

## Setup

### Dependencies

```bash
# Create virtual environment
uv venv torchrl --python 3.12
source torchrl/bin/activate

# Install PyTorch (adjust for your CUDA version)
uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128

# Install TorchRL and TensorDict
uv pip install tensordict torchrl

# Install additional dependencies
uv pip install mujoco dm_control wandb tqdm hydra-core
```

### System Dependencies (for MuJoCo rendering)

```bash
apt-get update && apt-get install -y \
    libegl1 \
    libgl1 \
    libgles2 \
    libglvnd0
```

### Environment Variables

```bash
export MUJOCO_GL=egl
export MUJOCO_EGL_DEVICE_ID=0
```

## Running

```bash
python dreamer.py
```

### Configuration

The default configuration trains on DMControl's `cheetah-run` task. You can override settings via command line:

```bash
# Different environment
python dreamer.py env.name=walker env.task=walk

# Mixed precision options: false, true (=bfloat16), float16, bfloat16
python dreamer.py optimization.autocast=bfloat16  # default
python dreamer.py optimization.autocast=float16   # for older GPUs
python dreamer.py optimization.autocast=false     # disable autocast

# Adjust batch size
python dreamer.py replay_buffer.batch_size=1000
```

## Known Caveats

### 1. Mixed Precision (Autocast) Compatibility

Some GPU/cuBLAS combinations have issues with `bfloat16` autocast, resulting in:
```
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasGemmEx
```

**Solutions:**
- Try float16: `optimization.autocast=float16`
- Or disable autocast entirely: `optimization.autocast=false`

Note: Ensure your PyTorch CUDA version matches your driver. For example, with CUDA 13.0:
```bash
uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu130
```
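For reference (not part of the packaged README), the `optimization.autocast` option corresponds to the standard `torch.autocast` pattern; a minimal sketch of that pattern is below. The exact integration in `dreamer.py` may differ, and `training_step` here is a hypothetical helper.

```python
# Sketch of the mixed-precision pattern the autocast option selects (assumed,
# not the exact dreamer.py code). float16 pairs with a GradScaler; bfloat16 does not need one.
import torch

use_autocast = True
autocast_dtype = torch.bfloat16  # switch to torch.float16 on older GPUs

scaler = torch.cuda.amp.GradScaler(
    enabled=use_autocast and autocast_dtype is torch.float16
)

def training_step(loss_fn, batch, optimizer):
    with torch.autocast("cuda", dtype=autocast_dtype, enabled=use_autocast):
        loss = loss_fn(batch)          # forward pass runs in reduced precision
    scaler.scale(loss).backward()      # scaling is a no-op unless float16 is enabled
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
```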
### 2. Benchmarking Status

This implementation has not been fully benchmarked against the original paper's results.
Performance may differ from published numbers.

### 3. Video Logging

To enable video logging of both real and imagined rollouts:
```bash
python dreamer.py logger.video=True
```

This requires additional setup for rendering and significantly increases computation time.

## Architecture Overview

```
World Model:
  - ObsEncoder: pixels -> encoded_latents
  - RSSMPrior: (state, belief, action) -> next_belief, prior_dist
  - RSSMPosterior: (belief, encoded_latents) -> posterior_dist, state
  - ObsDecoder: (state, belief) -> reconstructed_pixels
  - RewardModel: (state, belief) -> predicted_reward

Actor: (state, belief) -> action_distribution
Critic: (state, belief) -> state_value
```

## Training Loop

1. **Collect** real experience from environment
2. **Train world model** on sequences from replay buffer (KL + reconstruction + reward loss)
3. **Imagine** trajectories starting from encoded real states
4. **Train actor** to maximize imagined returns (gradients flow through dynamics)
5. **Train critic** to predict lambda returns on imagined trajectories
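A schematic sketch of how these five steps typically fit together (illustrative, not part of the packaged README and not the exact structure of `dreamer.py`; `collector`, `replay_buffer`, the `losses` and `optims` mappings, and the step count are placeholders):

```python
# Schematic Dreamer training loop (a sketch under the assumptions stated above).
# `losses` maps names to TorchRL-style loss modules (e.g. DreamerModelLoss,
# DreamerActorLoss, DreamerValueLoss) and `optims` maps names to their optimizers.
def train(collector, replay_buffer, losses, optims, optim_steps_per_batch=80):
    for data in collector:                       # 1. collect real experience
        replay_buffer.extend(data)
        for _ in range(optim_steps_per_batch):
            batch = replay_buffer.sample()       # sampled sequences of transitions
            # 2. world model: KL(prior || posterior) + reconstruction + reward losses
            # 3./4. imagination and the actor update happen inside the actor loss:
            #        gradients flow through the learned dynamics into the policy
            # 5. critic regresses lambda-returns computed on the imagined trajectories
            for name in ("world_model", "actor", "value"):
                loss_td = losses[name](batch)    # loss modules return a TensorDict of terms
                total = sum(v for k, v in loss_td.items() if k.startswith("loss"))
                total.backward()
                optims[name].step()
                optims[name].zero_grad()
        collector.update_policy_weights_()       # keep the acting policy in sync
```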
## References

- Original Paper: [Dream to Control: Learning Behaviors by Latent Imagination](https://arxiv.org/abs/1912.01603)
- PlaNet (predecessor): [Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551)
- DreamerV2: [Mastering Atari with Discrete World Models](https://arxiv.org/abs/2010.02193)
- DreamerV3: [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104)