torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/cql/cql_offline.py
@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""CQL Example.

This is a self-contained example of an offline CQL training script.

The helper functions are coded in the utils.py associated with this script.

"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
import tqdm
from tensordict.nn import CudaGraphModule
from torchrl._utils import get_available_device, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_continuous_cql_optimizer,
    make_continuous_loss,
    make_cql_model,
    make_environment,
    make_offline_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
def main(cfg: DictConfig):  # noqa: F821
    # Create logger
    exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="cql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )
    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    device = (
        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
    )

    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

    # Create env
    train_env, eval_env = make_environment(
        cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
    )

    # Create agent
    model = make_cql_model(cfg, train_env, eval_env, device)
    del train_env
    if hasattr(eval_env, "start"):
        # To set the number of threads to the definitive value
        eval_env.start()

    # Create loss
    loss_module, target_net_updater = make_continuous_loss(
        cfg.loss, model, device=device
    )

    # Create Optimizer
    (
        policy_optim,
        critic_optim,
        alpha_optim,
        alpha_prime_optim,
    ) = make_continuous_cql_optimizer(cfg, loss_module)

    # Group optimizers
    optimizer = group_optimizers(
        policy_optim, critic_optim, alpha_optim, alpha_prime_optim
    )

    def update(data, policy_eval_start, iteration):
        loss_vals = loss_module(data.to(device))

        # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
        actor_loss = torch.where(
            iteration >= policy_eval_start,
            loss_vals["loss_actor"],
            loss_vals["loss_actor_bc"],
        )
        q_loss = loss_vals["loss_qvalue"]
        cql_loss = loss_vals["loss_cql"]

        q_loss = q_loss + cql_loss
        loss_vals["q_loss"] = q_loss

        # update model
        alpha_loss = loss_vals["loss_alpha"]
        alpha_prime_loss = loss_vals["loss_alpha_prime"]
        if alpha_prime_loss is None:
            alpha_prime_loss = 0

        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # update qnet_target params
        target_net_updater.step()

        return loss.detach(), loss_vals.detach()

    compile_mode = None
    if cfg.compile.compile:
        if cfg.compile.compile_mode not in (None, ""):
            compile_mode = cfg.compile.compile_mode
        elif cfg.compile.cudagraphs:
            compile_mode = "default"
        else:
            compile_mode = "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

    gradient_steps = cfg.optim.gradient_steps
    policy_eval_start = cfg.optim.policy_eval_start
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps

    # Training loop
    policy_eval_start = torch.tensor(policy_eval_start, device=device)
    for i in range(gradient_steps):
        timeit.printevery(1000, gradient_steps, erase=True)
        pbar.update(1)
        # sample data
        with timeit("sample"):
            data = replay_buffer.sample()

        with timeit("update"):
            # compute loss
            torch.compiler.cudagraph_mark_step_begin()
            i_device = torch.tensor(i, device=device)
            loss, loss_vals = update(
                data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
            )

        # log metrics
        metrics_to_log = {
            "loss": loss.cpu(),
            **loss_vals.cpu(),
        }

        # evaluation
        with timeit("log/eval"):
            if i % evaluation_interval == 0:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad():
                    eval_td = eval_env.rollout(
                        max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                    )
                    eval_env.apply(dump_video)
                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
                metrics_to_log["evaluation_reward"] = eval_reward

        with timeit("log"):
            metrics_to_log.update(timeit.todict(prefix="time"))
            metrics_to_log["time/speed"] = pbar.format_dict["rate"]

            log_metrics(logger, metrics_to_log, i)

    pbar.close()
    if not eval_env.is_closed:
        eval_env.close()


if __name__ == "__main__":
    main()
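In the offline script's `update()` above, the conservative term is folded into the critic objective (`q_loss = q_loss + cql_loss`), and the actor objective is switched with `torch.where` between a behavior-cloning warm-up loss and the regular CQL actor loss, keeping the branch a pure tensor op. The sketch below isolates that switch; the numbers are hypothetical placeholders, not values from any shipped config:

```python
import torch

# Behavior-cloning warm-up switch, as in update() above (illustrative numbers only).
policy_eval_start = torch.tensor(40_000)  # gradient step at which BC warm-up ends
iteration = torch.tensor(10_000)          # current gradient step
loss_actor = torch.tensor(1.3)            # placeholder CQL actor loss
loss_actor_bc = torch.tensor(0.7)         # placeholder behavior-cloning loss

actor_loss = torch.where(iteration >= policy_eval_start, loss_actor, loss_actor_bc)
print(actor_loss)  # tensor(0.7000) -> still inside the BC warm-up phase
```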
sota-implementations/cql/cql_online.py
@@ -0,0 +1,249 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""CQL Example.

This is a self-contained example of an online CQL training script.

It works across Gym and MuJoCo over a variety of tasks.

The helper functions are coded in the utils.py associated with this script.

"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
import tqdm
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import get_available_device, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_collector,
    make_continuous_cql_optimizer,
    make_continuous_loss,
    make_cql_model,
    make_environment,
    make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="online_config")
def main(cfg: DictConfig):  # noqa: F821
    # Create logger
    exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="cql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    device = (
        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
    )

    # Create env
    train_env, eval_env = make_environment(
        cfg,
        cfg.env.train_num_envs,
        cfg.env.eval_num_envs,
        logger=logger,
    )

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        device="cpu",
    )

    # create agent
    model = make_cql_model(cfg, train_env, eval_env, device)

    compile_mode = None
    if cfg.compile.compile:
        if cfg.compile.compile_mode not in (None, ""):
            compile_mode = cfg.compile.compile_mode
        elif cfg.compile.cudagraphs:
            compile_mode = "default"
        else:
            compile_mode = "reduce-overhead"

    # Create collector
    collector = make_collector(
        cfg,
        train_env,
        actor_model_explore=model[0],
        compile=cfg.compile.compile,
        compile_mode=compile_mode,
        cudagraph=cfg.compile.cudagraphs,
    )

    # Create loss
    loss_module, target_net_updater = make_continuous_loss(
        cfg.loss, model, device=device
    )

    # Create optimizer
    (
        policy_optim,
        critic_optim,
        alpha_optim,
        alpha_prime_optim,
    ) = make_continuous_cql_optimizer(cfg, loss_module)
    optimizer = group_optimizers(
        policy_optim, critic_optim, alpha_optim, alpha_prime_optim
    )

    def update(sampled_tensordict):

        loss_td = loss_module(sampled_tensordict)

        actor_loss = loss_td["loss_actor"]
        q_loss = loss_td["loss_qvalue"]
        cql_loss = loss_td["loss_cql"]
        q_loss = q_loss + cql_loss
        alpha_loss = loss_td["loss_alpha"]
        alpha_prime_loss = loss_td["loss_alpha_prime"]

        total_loss = alpha_loss + actor_loss + alpha_prime_loss + q_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # update qnet_target params
        target_net_updater.step()

        return loss_td.detach()

    if compile_mode:
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

    # Main loop
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
    prb = cfg.replay_buffer.prb
    frames_per_batch = cfg.collector.frames_per_batch
    evaluation_interval = cfg.logger.log_interval
    eval_rollout_steps = cfg.logger.eval_steps

    c_iter = iter(collector)
    total_iter = len(collector)
    for i in range(total_iter):
        timeit.printevery(1000, total_iter, erase=True)
        with timeit("collecting"):
            tensordict = next(c_iter)
        pbar.update(tensordict.numel())
        # update weights of the inference policy
        collector.update_policy_weights_()

        with timeit("rb - extend"):
            tensordict = tensordict.view(-1)
            current_frames = tensordict.numel()
            # add to replay buffer
            replay_buffer.extend(tensordict)
        collected_frames += current_frames

        if collected_frames >= init_random_frames:
            log_loss_td = TensorDict(batch_size=[num_updates], device=device)
            for j in range(num_updates):
                pbar.set_description(f"optim iter {j}")
                with timeit("rb - sample"):
                    # sample from replay buffer
                    sampled_tensordict = replay_buffer.sample().to(device)

                with timeit("update"):
                    torch.compiler.cudagraph_mark_step_begin()
                    loss_td = update(sampled_tensordict)
                log_loss_td[j] = loss_td.detach()
                # update priority
                if prb:
                    with timeit("rb - update priority"):
                        replay_buffer.update_priority(sampled_tensordict)

        episode_rewards = tensordict["next", "episode_reward"][
            tensordict["next", "done"]
        ]
        # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][
                tensordict["next", "done"]
            ]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )
        if collected_frames >= init_random_frames:
            metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean()
            metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean()
            metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean()
            metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
                "loss_alpha_prime"
            ).mean()
            metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()

        # Evaluation
        with timeit("eval"):
            prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval
            cur_test_frame = (i * frames_per_batch) // evaluation_interval
            final = current_frames >= collector.total_frames
            if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad():
                    eval_rollout = eval_env.rollout(
                        eval_rollout_steps,
                        model[0],
                        auto_cast_to_device=True,
                        break_when_any_done=True,
                    )
                    eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                    eval_env.apply(dump_video)
                    metrics_to_log["eval/reward"] = eval_reward

        metrics_to_log.update(timeit.todict(prefix="time"))
        metrics_to_log["time/speed"] = pbar.format_dict["rate"]
        log_metrics(logger, metrics_to_log, collected_frames)

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()
    if not train_env.is_closed:
        train_env.close()


if __name__ == "__main__":
    main()
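The online script gates evaluation on cumulative frames rather than on iteration count: a rollout runs whenever `i * frames_per_batch` crosses a new multiple of `evaluation_interval` (or on the final batch). The standalone sketch below replays that gating logic; the numbers are illustrative only and not taken from any shipped config:

```python
# Evaluation gating as in the online loop above (illustrative numbers only).
frames_per_batch = 1000
evaluation_interval = 2500

for i in range(6):
    prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval
    cur_test_frame = (i * frames_per_batch) // evaluation_interval
    run_eval = i >= 1 and prev_test_frame < cur_test_frame
    print(i, run_eval)  # True at i=3 (2000 -> 3000 frames) and i=5 (4000 -> 5000 frames)
```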
sota-implementations/cql/discrete_cql_offline.py
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""CQL Example.

This is a self-contained example of a discrete offline CQL training script.

The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
import tqdm
from tensordict.nn import CudaGraphModule
from torchrl._utils import get_available_device, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_discrete_cql_optimizer,
    make_discrete_loss,
    make_discretecql_model,
    make_environment,
    make_offline_discrete_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
def main(cfg):  # noqa: F821
    device = (
        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
    )

    # Create logger
    exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="discretecql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    if cfg.env.seed is not None:
        warnings.warn(
            "The seed in the environment config is deprecated. "
            "Please set the seed in the optim config instead."
        )

    # Create replay buffer
    replay_buffer = make_offline_discrete_replay_buffer(cfg.replay_buffer)

    # Create env
    train_env, eval_env = make_environment(
        cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
    )

    # Create agent
    model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

    del train_env

    # Create loss
    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device)

    # Create optimizers
    optimizer = make_discrete_cql_optimizer(cfg, loss_module)  # optimizer for CQL loss

    def update(data):

        # Compute loss components
        loss_vals = loss_module(data)

        q_loss = loss_vals["loss_qvalue"]
        cql_loss = loss_vals["loss_cql"]

        # Total loss = Q-learning loss + CQL regularization
        loss = q_loss + cql_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Soft update of target Q-network
        target_net_updater.step()

        # Detach to avoid keeping computation graph in logging
        return loss.detach(), loss_vals.detach()

    compile_mode = None
    if cfg.compile.compile:
        if cfg.compile.compile_mode not in (None, ""):
            compile_mode = cfg.compile.compile_mode
        elif cfg.compile.cudagraphs:
            compile_mode = "default"
        else:
            compile_mode = "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

    gradient_steps = cfg.optim.gradient_steps
    policy_eval_start = cfg.optim.policy_eval_start
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps

    # Training loop
    policy_eval_start = torch.tensor(policy_eval_start, device=device)
    for i in range(gradient_steps):
        timeit.printevery(1000, gradient_steps, erase=True)
        pbar.update(1)
        # sample data
        with timeit("sample"):
            data = replay_buffer.sample()

        with timeit("update"):
            torch.compiler.cudagraph_mark_step_begin()
            loss, loss_vals = update(data.to(device))

        # log metrics
        metrics_to_log = {
            "loss": loss.cpu(),
            **loss_vals.cpu(),
        }

        # evaluation
        with timeit("log/eval"):
            if i % evaluation_interval == 0:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad():
                    eval_td = eval_env.rollout(
                        max_steps=eval_steps,
                        policy=explore_policy,
                        auto_cast_to_device=True,
                    )
                    eval_env.apply(dump_video)

                # eval_td has shape [num_episodes, max_steps, ...]
                eval_reward = (
                    eval_td["next", "reward"].sum(1).mean().item()
                )  # mean over the per-episode reward sums
                metrics_to_log["evaluation_reward"] = eval_reward

        with timeit("log"):
            metrics_to_log.update(timeit.todict(prefix="time"))
            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
            log_metrics(logger, metrics_to_log, i)

    pbar.close()
    if not eval_env.is_closed:
        eval_env.close()


if __name__ == "__main__":
    main()