torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Public entry point for :mod:`torchrl.modules.models`.

This module only re-exports the model building blocks defined in the sibling
submodules (batch renorm, decision transformer, exploration layers, Dreamer
components, MLP/conv nets, multi-agent nets and small utility layers).
``__all__`` below is the authoritative list of public names.
"""

from torchrl.modules.tensordict_module.common import DistributionalDQNnet

from .batchrenorm import BatchRenorm1d

from .decision_transformer import DecisionTransformer
from .exploration import (
    ConsistentDropout,
    ConsistentDropoutModule,
    NoisyLazyLinear,
    NoisyLinear,
    reset_noise,
)
from .llm import GPT2RewardModel
from .model_based import (
    DreamerActor,
    ObsDecoder,
    ObsEncoder,
    RSSMPosterior,
    RSSMPrior,
    RSSMRollout,
)
from .models import (
    Conv2dNet,
    Conv3dNet,
    ConvNet,
    DdpgCnnActor,
    DdpgCnnQNet,
    DdpgMlpActor,
    DdpgMlpQNet,
    DTActor,
    DuelingCnnDQNet,
    DuelingMlpDQNet,
    MLP,
    OnlineDTActor,
)
from .multiagent import (
    MultiAgentConvNet,
    MultiAgentMLP,
    MultiAgentNetBase,
    QMixer,
    VDNMixer,
)
from .utils import Squeeze2dLayer, SqueezeLayer

# Names re-exported as the public API of this subpackage.
__all__ = [
    "DistributionalDQNnet",
    "BatchRenorm1d",
    "DecisionTransformer",
    "GPT2RewardModel",
    "ConsistentDropout",
    "ConsistentDropoutModule",
    "NoisyLazyLinear",
    "NoisyLinear",
    "reset_noise",
    "DreamerActor",
    "ObsDecoder",
    "ObsEncoder",
    "RSSMPosterior",
    "RSSMPrior",
    "RSSMRollout",
    "Conv2dNet",
    "Conv3dNet",
    "ConvNet",
    "DdpgCnnActor",
    "DdpgCnnQNet",
    "DdpgMlpActor",
    "DdpgMlpQNet",
    "DTActor",
    "DuelingCnnDQNet",
    "DuelingMlpDQNet",
    "MLP",
    "OnlineDTActor",
    "MultiAgentConvNet",
    "MultiAgentMLP",
    "MultiAgentNetBase",
    "QMixer",
    "VDNMixer",
    "Squeeze2dLayer",
    "SqueezeLayer",
]
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch
import torch.nn as nn


class BatchRenorm1d(nn.Module):
    """BatchRenorm Module (https://arxiv.org/abs/1702.03275).

    The code is adapted from https://github.com/google-research/corenet

    BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
    it utilizes running statistics to normalize batches after an initial warmup phase.
    This approach reduces the impact of "outlier" batches that may occur during
    extended training periods, making BatchRenorm more robust for long training runs.

    During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.

    Args:
        num_features (int): Number of features in the input tensor.

    Keyword Args:
        momentum (:obj:`float`, optional): Momentum factor for computing the running mean and variance.
            Defaults to ``0.01``.
        eps (:obj:`float`, optional): Small value added to the variance to avoid division by zero.
            Defaults to ``1e-5``.
        max_r (:obj:`float`, optional): Maximum value for the scaling factor r.
            Defaults to ``3.0``.
        max_d (:obj:`float`, optional): Maximum value for the bias factor d.
            Defaults to ``5.0``.
        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
            Defaults to ``10000``.
        smooth (bool, optional): if ``True``, the behavior smoothly transitions from regular
            batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``).
            Otherwise, the behavior will transition from batch-norm to batch-renorm when
            ``iter=warmup_steps``. Defaults to ``False``.
    """

    def __init__(
        self,
        num_features: int,
        *,
        momentum: float = 0.01,
        eps: float = 1e-5,
        max_r: float = 3.0,
        max_d: float = 5.0,
        warmup_steps: int = 10000,
        smooth: bool = False,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.max_r = max_r
        self.max_d = max_d
        self.warmup_steps = warmup_steps
        self.smooth = smooth

        # Running statistics (updated in training mode only) and the step
        # counter that gates the BatchNorm -> BatchRenorm transition.
        self.register_buffer(
            "running_mean", torch.zeros(num_features, dtype=torch.float32)
        )
        self.register_buffer(
            "running_var", torch.ones(num_features, dtype=torch.float32)
        )
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64))
        # Learnable affine parameters, same role as in nn.BatchNorm1d.
        self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over every dimension except dim 1 (the feature dim).

        Args:
            x (torch.Tensor): input of shape ``(N, num_features, *)`` with at
                least 2 dimensions.

        Returns:
            torch.Tensor: normalized tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if not x.dim() >= 2:
            raise ValueError(
                f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}."
            )

        # Shape used to broadcast per-feature statistics against ``x``.
        view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)

        def _v(v):
            return v.view(view_dims)

        # ``sqrt_`` operates in-place on the temporary ``running_var + eps``,
        # so the registered buffer itself is left untouched.
        running_std = (self.running_var + self.eps).sqrt_()

        if self.training:
            reduce_dims = [i for i in range(x.dim()) if i != 1]
            b_mean = x.mean(reduce_dims)
            b_var = x.var(reduce_dims, unbiased=False)
            b_std = (b_var + self.eps).sqrt_()

            # Renorm correction terms (r, d), clamped as in the paper.
            r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r)
            d = torch.clamp(
                (b_mean.detach() - self.running_mean) / running_std,
                -self.max_d,
                self.max_d,
            )

            # Compute warmup factor (0 during warmup, 1 after warmup)
            if self.warmup_steps > 0:
                if self.smooth:
                    warmup_factor = self.num_batches_tracked / self.warmup_steps
                else:
                    warmup_factor = self.num_batches_tracked // self.warmup_steps
                r = 1.0 + (r - 1.0) * warmup_factor
                d = d * warmup_factor

            x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)

            # NOTE(review): the unbiased correction uses ``x.shape[0]`` (batch
            # dim only); for inputs with trailing dims the per-feature sample
            # count is larger -- confirm this is intended before changing it.
            unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1)
            self.running_var += self.momentum * (unbiased_var - self.running_var)
            self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
            self.num_batches_tracked += 1
            # BUGFIX: ``clamp_max`` is out-of-place and its result was being
            # discarded, so the counter grew without bound and, after warmup,
            # ``warmup_factor`` exceeded 1 and over-amplified r and d. Clamp
            # the buffer in place instead.
            self.num_batches_tracked.clamp_max_(self.warmup_steps)
        else:
            x = (x - _v(self.running_mean)) / _v(running_std)

        x = _v(self.weight) * x + _v(self.bias)
        return x
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import dataclasses
import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

# transformers is an optional dependency; availability is checked lazily so
# that importing this module does not require it.
_has_transformers = importlib.util.find_spec("transformers") is not None


class DecisionTransformer(nn.Module):
    """Online Decision Transformer.

    Described in https://arxiv.org/abs/2202.05607 .

    A GPT2 backbone consumes interleaved (return-to-go, state, action) token
    embeddings and the state-token outputs are returned. If the user does not
    provide a config, the GPT2 model is built from the default configuration::

        default_config = {
            "n_embd": 256,
            "n_layer": 4,
            "n_head": 4,
            "n_inner": 1024,
            "activation": "relu",
            "n_positions": 1024,
            "resid_pdrop": 0.1,
            "attn_pdrop": 0.1,
        }

    Args:
        state_dim (int): dimension of the state space
        action_dim (int): dimension of the action space
        config (:obj:`~.DTConfig` or dict, optional): transformer architecture configuration,
            used to create the GPT2Config from transformers.
            Defaults to ``default_config``.
        device (torch.device, optional): device on which the submodules are built.

    Raises:
        ImportError: if the optional ``transformers`` dependency is missing.
        TypeError: if ``config`` cannot be converted to a dict.

    Example:
        >>> config = DecisionTransformer.default_config()
        >>> config.n_embd = 128
        >>> print(config)
        DTConfig(n_embd: 128, n_layer: 4, n_head: 4, n_inner: 1024, activation: relu, n_positions: 1024, resid_pdrop: 0.1, attn_pdrop: 0.1)
        >>> # alternatively
        >>> config = DecisionTransformer.DTConfig(n_embd=128)
        >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config)
        >>> batch_size = [3, 32]
        >>> length = 10
        >>> observation = torch.randn(*batch_size, length, 4)
        >>> action = torch.randn(*batch_size, length, 2)
        >>> return_to_go = torch.randn(*batch_size, length, 1)
        >>> output = model(observation, action, return_to_go)
        >>> output.shape
        torch.Size([3, 32, 10, 128])

    """

    @dataclass
    class DTConfig:
        """Default configuration for DecisionTransformer.

        Fields mirror the keyword arguments of ``transformers.GPT2Config``
        (``activation`` maps to ``activation_function``).
        """

        n_embd: Any = 256
        n_layer: Any = 4
        n_head: Any = 4
        n_inner: Any = 1024
        activation: Any = "relu"
        n_positions: Any = 1024
        resid_pdrop: Any = 0.1
        attn_pdrop: Any = 0.1

        def __repr__(self):
            # Custom repr matching the format shown in the class docstring
            # example ("DTConfig(n_embd: 128, ...)").
            fields = []
            for f in dataclasses.fields(self):
                value = getattr(self, f.name)
                fields.append(f"{f.name}: {value}")
            fields = ", ".join(fields)
            return f"{self.__class__.__name__}({fields})"

    @classmethod
    def default_config(cls):
        """Return a fresh :class:`DTConfig` populated with the default values."""
        return cls.DTConfig()

    def __init__(
        self,
        state_dim,
        action_dim,
        config: dict | DTConfig | None = None,
        device: torch.device | None = None,
    ):

        if not _has_transformers:
            raise ImportError(
                "transformers is not installed. Please install it with `pip install transformers`."
            )
        # Imported lazily so the module can be imported without transformers.
        import transformers
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model

        # Normalize ``config`` to a plain dict (DTConfig, dict, or any
        # dict-convertible mapping are accepted).
        if config is None:
            config = self.default_config()
        if isinstance(config, self.DTConfig):
            config = dataclasses.asdict(config)
        if not isinstance(config, dict):
            try:
                config = dict(config)
            except Exception as err:
                raise TypeError(
                    f"Config of type {type(config)} is not supported."
                ) from err

        super().__init__()

        # Build all submodules under the target device context when one is
        # provided; otherwise fall back to a no-op context.
        with torch.device(device) if device is not None else nullcontext():
            gpt_config = transformers.GPT2Config(
                n_embd=config["n_embd"],
                n_layer=config["n_layer"],
                n_head=config["n_head"],
                n_inner=config["n_inner"],
                activation_function=config["activation"],
                n_positions=config["n_positions"],
                resid_pdrop=config["resid_pdrop"],
                attn_pdrop=config["attn_pdrop"],
                # vocab_size is unused: inputs are fed as embeddings, not ids.
                vocab_size=1,
            )
            self.state_dim = state_dim
            self.action_dim = action_dim
            self.hidden_size = config["n_embd"]

            self.transformer = GPT2Model(config=gpt_config)

            # One embedding head per modality (return-to-go, state, action).
            self.embed_return = torch.nn.Linear(1, self.hidden_size)
            self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
            self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)

            self.embed_ln = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> torch.Tensor:
        """Run the transformer and return the hidden states of the state tokens.

        Inputs are expected to share the layout ``(*batch, seq, feature)``;
        arbitrary leading batch dims are flattened before the transformer call
        and restored on the output.
        """
        batch_size, seq_length = observation.shape[:-2], observation.shape[-2]
        batch_size_orig = batch_size
        if len(batch_size) != 1:
            # TODO: vmap over transformer once this is possible
            # Flatten all leading batch dims into one for the GPT2 call.
            observation = observation.view(-1, *observation.shape[-2:])
            action = action.view(-1, *action.shape[-2:])
            return_to_go = return_to_go.view(-1, *return_to_go.shape[-2:])
            batch_size = torch.Size([batch_size.numel()])

        # embed each modality with a different head
        state_embeddings = self.embed_state(observation)
        action_embeddings = self.embed_action(action)
        returns_embeddings = self.embed_return(return_to_go)

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=-2
        ).reshape(*batch_size, 3 * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
        )
        x = transformer_outputs["last_hidden_state"]

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).transpose(-3, -2)
        # Identity check: ``batch_size`` is only rebound when the inputs were
        # flattened above, so ``is`` distinguishes the two paths.
        if batch_size_orig is batch_size:
            return x[..., 1, :, :]  # only state tokens
        return x[..., 1, :, :].reshape(*batch_size_orig, *x.shape[-2:])