torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,435 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import gc
+import os
+from functools import partial
+from pathlib import Path
+
+import hydra
+
+from torchrl import merge_ray_runtime_env, torchrl_logger
+from torchrl.data.llm.history import History
+from torchrl.record.loggers.wandb import WandbLogger
+from torchrl.weight_update.llm import get_model_metadata
+
+try:
+    import ray
+except ImportError:
+    raise ImportError(
+        "Ray is required for sync training. Please install ray with `pip install ray`."
+    )
+import time
+
+import torch
+import tqdm
+
+from grpo_utils import (
+    add_kl_transforms_to_replay_buffer,
+    check_grpo_dependencies,
+    compute_device_allocation,
+    get_inference_model,
+    get_train_model,
+    log_training_metrics,
+    make_env,
+    make_weight_sync_scheme,
+)
+from omegaconf import DictConfig
+
+try:
+    from tensordict import set_list_to_stack
+except ImportError:
+    raise ImportError(
+        "TensorDict is required. Please install it with `pip install tensordict`."
+    )
+from torch.amp.autocast_mode import autocast
+from torch.amp.grad_scaler import GradScaler
+from torchrl._utils import timeit
+from torchrl.collectors.llm import RayLLMCollector
+from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
+from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
+from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
+
+
+def setup_environment() -> None:
+    """Setup required environment variables and configurations."""
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for training")
+
+    # Set default dtype to float32 for mixed precision training
+    torch.set_default_dtype(torch.float32)
+    torch.set_default_device("cuda:0")
+    set_list_to_stack(True).set()
+
+    # Ensure CUDA is using the correct dtype
+    if torch.cuda.is_available():
+        torch.cuda.set_device("cuda:0")
+
+
+def train(
+    replay_buffer: ReplayBuffer,
+    cfg: DictConfig,
+    collector: RayLLMCollector,
+    inference_policy,
+    devices: list[int] | None = None,
+):
+    """Main training loop for GRPO sync.
+
+    This function implements synchronous training where data collection and optimization
+    happen in separate, consecutive steps. The total number of steps is determined by the number of epochs,
+    samples per epoch, and batches collected.
+
+    Args:
+        replay_buffer: The replay buffer to store experiences. The sampler will typically be a `SamplerWithoutReplacement`.
+        cfg: The configuration object containing training parameters
+        collector: The collector object.
+        devices: The devices to use for the training model.
+    """
+    # Setup training model and tokenizer
+    policy_training, train_tokenizer = get_train_model(cfg, devices=devices)
+    train_device = torch.device(f"cuda:{devices[0]}" if devices else "cuda:0")
+
+    # Setup loss function
+    loss_fn = GRPOLoss(
+        actor_network=policy_training,
+        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff
+        if (cfg.train.kl_coef_in_loss and cfg.train.use_kl_to_ref)
+        else 0.0,
+        kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
+        entropy_coeff=cfg.train.entropy_coeff,
+        masking_strategy="rlhf" if cfg.env.reasoning else "sft",
+        device=train_device,
+    )
+    if cfg.env.reasoning:
+        # TODO: this is clunky, we should find a way to do this more naturally
+        loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
+    if cfg.model.compile:
+        loss_fn = torch.compile(loss_fn)
+
+    vllm_engine = inference_policy.model
+
+    # Create weight sync scheme
+    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)
+
+    # Set up weight sender
+    torchrl_logger.info("Setting up weight synchronization scheme...")
+    sender = weight_sync_scheme.create_sender()
+    sender.register_model(policy_training)
+
+    # Initialize collective group
+    torchrl_logger.info("Initializing collective group...")
+    metadata = get_model_metadata(policy_training)
+    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
+
+    # First weight update
+    with timeit("update_policy_weights"):
+        sender.update_weights()
+    timeit.print(prefix="First update_policy_weights_ time")
+    timeit.reset()
+
+    # Make optimizer
+    torchrl_logger.info("Starting optimizer.")
+    optimizer = torch.optim.Adam(
+        policy_training.parameters(),
+        lr=cfg.optimizer.lr,
+        weight_decay=cfg.optimizer.weight_decay,
+        fused=False,
+    )
+    scaler = GradScaler(enabled=cfg.train.mixed_precision)
+
+    # Make checkpoint dir
+    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    # Make wandb logger
+    torchrl_logger.info("Starting wandb logger.")
+    experiment_name = cfg.logging.experiment_name
+    if experiment_name is not None:
+        experiment_name = [experiment_name]
+    else:
+        experiment_name = []
+
+    experiment_name.append(cfg.env.dataset)
+    experiment_name.append(cfg.model.name)
+    wandb_logger = WandbLogger(
+        project="grpo-sync", exp_name="-".join(["grpo-sync"] + experiment_name)
+    )
+
+    # Training loop
+    torchrl_logger.info("Starting training loop.")
+    pbar = tqdm.tqdm(collector)
+    grad_norm = 0.0  # Initialize grad_norm
+    data_read_count = 0
+
+    global_step = 0
+    start_time = time.time()
+    for data in pbar:
+        # Wait for the replay buffer to be filled - when reasoning, we collect trajectories
+        # so the buffer may not be filled straight away
+        if not len(replay_buffer):
+            torchrl_logger.info("Waiting for replay buffer to be filled")
+            continue
+        else:
+            torchrl_logger.info(f"Replay buffer filled: {len(replay_buffer)}")
+
+        pbar.update(1)
+
+        # data is None as the collector directly writes to the replay buffer
+        if data is not None:
+            raise ValueError("Data is not None")
+
+        for _ in range(cfg.train.epochs):
+            # Iterate over the replay buffer
+            for batch in replay_buffer:
+                batch = batch.to(train_device)
+                global_step += 1
+                pbar.set_description(
+                    f"Gradient step {global_step}, writes: {replay_buffer.write_count}, batch size: {batch.shape}"
+                )
+                history: History = batch.view(-1)[0]["next", "history"].prompt
+                history_str: list[str] | str = history.apply_chat_template(
+                    tokenizer=train_tokenizer
+                )
+                while not isinstance(history_str, str):
+                    history_str = "\n".join(history_str)
+
+                data_read_count += batch.numel()
+
+                with timeit("forward_pass"):
+                    with autocast("cuda", enabled=cfg.train.mixed_precision):
+                        loss = loss_fn(batch)
+                        loss_val = (
+                            loss.mean(reduce=True)
+                            / cfg.train.gradient_accumulation_steps
+                        )
+
+                with timeit("backward_pass"):
+                    if (
+                        cfg.train.mixed_precision
+                        and cfg.train_model.torch_dtype == "float16"
+                    ):
+                        scaler = GradScaler(enabled=True)
+                        scaler.scale(loss_val).backward()
+                    else:
+                        loss_val.backward()
+
+                if ((global_step + 1) % cfg.train.gradient_accumulation_steps) == 0:
+                    with timeit("optim_step"):
+                        if (
+                            cfg.train.mixed_precision
+                            and cfg.train_model.torch_dtype == "float16"
+                        ):
+                            scaler.unscale_(optimizer)
+
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            policy_training.parameters(),
+                            cfg.optimizer.clip_grad_norm,
+                        )
+
+                        if (
+                            cfg.train.mixed_precision
+                            and cfg.train_model.torch_dtype == "float16"
+                        ):
+                            scaler.step(optimizer)
+                            scaler.update()
+                        else:
+                            optimizer.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                del loss_val
+                # TODO: do we need this? Does it interfere with other processes?
+                # torch.cuda.empty_cache()
+                gc.collect()
+
+                if (global_step % cfg.train.logging_frequency) == 0:
+                    log_training_metrics(
+                        wandb_logger=wandb_logger,
+                        replay_buffer=replay_buffer,
+                        batch=batch,
+                        loss=loss,
+                        grad_norm=grad_norm,
+                        global_step=global_step,
+                        data_read_count=data_read_count,
+                        collector=collector,
+                        start_time=start_time,
+                        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+                        history_str=history_str,
+                        use_kl_to_ref=cfg.train.use_kl_to_ref,
+                    )
+
+                # Checkpointing disabled to prevent disk space issues
+                # if (global_step + 1) % cfg.train.checkpoint_frequency == 0:
+                #     with timeit("save_checkpoint"):
+                #         torchrl_logger.info(
+                #             f"Saving checkpoint {(global_step+1) // cfg.train.checkpoint_frequency}..."
+                #         )
+                #         checkpoint = {
+                #             "step": global_step,
+                #             "model_state_dict": policy_training.model.state_dict(),
+                #             "optimizer_state_dict": optimizer.state_dict(),
+                #             "scaler_state_dict": scaler.state_dict(),
+                #             "config": dict(cfg),
+                #         }
+                #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{global_step:04d}.pt")
+
+        with timeit("update_policy_weights"):
+            torchrl_logger.info("Updating policy weights...")
+            sender.update_weights()
+        # TODO: do we need this? Does it interfere with other processes?
+        # torch.cuda.empty_cache()
+        gc.collect()
+
+        timeit.print(prefix="timeit")
+        for key, val in timeit.todict().items():
+            wandb_logger.log_scalar(f"timeit/{key}", val)
+        timeit.reset()
+
+        if cfg.train.empty_replay_buffer:
+            replay_buffer.empty(empty_write_count=False)
+
+    pbar.close()
+    collector.shutdown()
+
+
+@hydra.main(version_base=None, config_path="config", config_name="grpo_gsm8k")
+def main(cfg):
+    # Check for required GRPO dependencies
+    check_grpo_dependencies()
+
+    # Force sync mode
+    if not cfg.train.sync:
+        raise ValueError(
+            "grpo-sync.py must run in sync mode (`python grpo-sync.py mode=sync`). Please use grpo-async.py for async mode (`python grpo-async.py mode=async`)."
+        )
+    if cfg.train.weight_update_frequency is not None:
+        raise ValueError("weight_update_frequency must be left empty in sync mode.")
+
+    # Compute device allocation
+    device_config = compute_device_allocation(cfg)
+
+    if not ray.is_initialized():
+        # Convert OmegaConf to regular dict and filter out unsupported parameters
+        ray_init_config = {
+            k: dict(v) if isinstance(v, DictConfig) else v
+            for k, v in dict(cfg.ray.init_config).items()
+            if not k.startswith("_")
+        }
+
+        # Add computed GPU configuration and merge with default runtime_env
+        ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
+        ray_init_config = merge_ray_runtime_env(ray_init_config)
+        torchrl_logger.info(f"Ray init config: {ray_init_config=}")
+        ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+        if ray_managed_externally:
+            ray.init(address="auto")
+        else:
+            ray.init(**ray_init_config)
+
+    # Check if num_devices is set
+    if cfg.inference_model.num_devices is None:
+        raise ValueError(
+            "Inference model num_devices must be set via inference_model.num_devices"
+        )
+    if cfg.train.use_kl_to_ref and cfg.ref_model.num_devices is None:
+        raise ValueError(
+            "Ref model num_devices must be set via ref_model.num_devices when use_kl_to_ref is True"
+        )
+    if cfg.train_model.num_devices is None:
+        raise ValueError(
+            "Train model num_devices must be set via train_model.num_devices"
+        )
+
+    # Convert OmegaConf to regular dict for Ray configs
+    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
+    collector_config = dict(cfg.ray.collector_config)
+    train_handler_config = dict(cfg.ray.train_handler_config)
+
+    inference_policy = get_inference_model(
+        cfg, devices=device_config["inference_model_devices"]
+    )
+    torchrl_logger.info(f"Inference policy: {inference_policy}")
+
+    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
+    if cfg.train.buffer_size is not None and (
+        cfg.train.buffer_size != cfg.train.dialog_turns_per_batch
+    ):
+        raise ValueError(
+            "buffer_size must be equal to dialog_turns_per_batch in sync settings."
+        )
+
+    if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
+        raise ValueError(
+            "optim_batch_size must be divisible by gradient_accumulation_steps"
+        )
+
+    rb = RayReplayBuffer(
+        storage=partial(
+            LazyStackStorage,
+            # Since we cache the values in the queue until we have "repeats" samples,
+            # the buffer can be bigger than what the dialog_turns_per_batch (at most repeats * num_envs)
+            cfg.train.buffer_size
+            if cfg.train.buffer_size
+            else cfg.env.repeats * cfg.env.num_envs,
+        ),
+        sampler=SamplerWithoutReplacement,
+        transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
+        batch_size=max(
+            1, cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
+        ),
+        remote_config=replay_buffer_config,
+    )
+
+    add_kl_transforms_to_replay_buffer(rb, cfg)
+
+    torchrl_logger.info(f"Replay buffer: {rb}")
+
+    collector_config["num_gpus"] = 0
+    collector_config["num_cpus"] = cfg.ray.collector_config.get("num_cpus", 1)
+    torchrl_logger.info(f"Starting collector with {collector_config=}")
+
+    collector = RayLLMCollector(
+        env=partial(make_env, cfg),
+        policy=inference_policy,
+        dialog_turns_per_batch=cfg.train.dialog_turns_per_batch,
+        total_dialog_turns=cfg.train.total_dialog_turns,
+        replay_buffer=rb,
+        ray_init_config=None,
+        weight_updater=None,
+        track_policy_version=True,
+        remote_config=collector_config,
+        sync_iter=cfg.train.sync_iter,
+        verbose=False,
+        yield_only_last_steps=cfg.env.reasoning,
+    )
+    ray.get(collector._collector.is_initialized.remote())
+    torchrl_logger.info(f"Collector: {collector}")
+
+    train_handler_config = {
+        "num_cpus": train_handler_config.get("num_cpus", 1),
+        "num_gpus": cfg.train_model.num_devices,
+    }
+    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
+    train_handler = ray.remote(
+        **train_handler_config,
+    )(train)
+
+    # launch training
+    ray.get(
+        train_handler.remote(
+            rb,
+            cfg,
+            collector,
+            inference_policy,
+            devices=device_config["train_model_devices"],
+        )
+    )
+
+
+if __name__ == "__main__":
+    # Setup environment
+    setup_environment()
+    main()