torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from functools import partial
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import hydra
|
|
15
|
+
|
|
16
|
+
from torchrl import merge_ray_runtime_env, torchrl_logger
|
|
17
|
+
from torchrl.data.llm.history import History
|
|
18
|
+
from torchrl.record.loggers.wandb import WandbLogger
|
|
19
|
+
from torchrl.weight_update.llm import get_model_metadata
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import ray
|
|
23
|
+
except ImportError:
|
|
24
|
+
raise ImportError(
|
|
25
|
+
"Ray is required for async training. Please install ray with `pip install ray`."
|
|
26
|
+
)
|
|
27
|
+
import torch
|
|
28
|
+
import tqdm
|
|
29
|
+
|
|
30
|
+
from grpo_utils import (
|
|
31
|
+
add_kl_transforms_to_replay_buffer,
|
|
32
|
+
check_grpo_dependencies,
|
|
33
|
+
compute_device_allocation,
|
|
34
|
+
get_inference_model,
|
|
35
|
+
get_train_model,
|
|
36
|
+
log_training_metrics,
|
|
37
|
+
make_env,
|
|
38
|
+
make_weight_sync_scheme,
|
|
39
|
+
)
|
|
40
|
+
from omegaconf import DictConfig
|
|
41
|
+
|
|
42
|
+
try:
|
|
43
|
+
from tensordict import set_list_to_stack
|
|
44
|
+
except ImportError:
|
|
45
|
+
raise ImportError(
|
|
46
|
+
"TensorDict is required. Please install it with `pip install tensordict`."
|
|
47
|
+
)
|
|
48
|
+
from torch.amp.autocast_mode import autocast
|
|
49
|
+
from torch.amp.grad_scaler import GradScaler
|
|
50
|
+
from torchrl._utils import timeit
|
|
51
|
+
from torchrl.collectors.llm import RayLLMCollector
|
|
52
|
+
from torchrl.data import LazyStackStorage, ReplayBuffer
|
|
53
|
+
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
|
|
54
|
+
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def setup_environment() -> None:
    """Configure process-wide torch/tensordict state for single-GPU training.

    Sets float32 as the default dtype, makes ``cuda:0`` both the default
    tensor device and the current CUDA device, and enables tensordict's
    ``set_list_to_stack`` behaviour globally.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Guard clause: everything below pins global state to the first GPU.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    primary_device = "cuda:0"

    # Default to float32 for mixed precision training.
    torch.set_default_dtype(torch.float32)
    torch.set_default_device(primary_device)
    set_list_to_stack(True).set()

    # Availability was verified above, so select the device unconditionally.
    torch.cuda.set_device(primary_device)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collectors: list[RayLLMCollector],
    inference_policy,
    devices: list[int] | None = None,
):
    """Main training loop for GRPO async.

    This function implements asynchronous training where data collection and
    optimization happen concurrently. It:

    1. Builds the training policy and the GRPO loss.
    2. Sets up weight synchronization with the inference engine and pushes
       the initial weights.
    3. Starts the collectors and waits until the replay buffer has been
       written to.
    4. Repeatedly samples from the buffer, computes the loss and accumulates
       gradients. An optimizer step is taken every
       ``cfg.train.gradient_accumulation_steps`` iterations, and weights are
       pushed to the inference engine every
       ``cfg.train.weight_update_frequency`` iterations.

    The number of iterations is
    ``ceil(cfg.train.total_dialog_turns / cfg.train.optim_batch_size) * cfg.train.epochs``.
    The loop stops early if no collector is running anymore.

    Args:
        replay_buffer: The replay buffer the collectors write experiences to.
        cfg: The configuration object containing training parameters.
        collectors: The collector objects. They are started here and are
            all shut down when this function exits, including on error.
        inference_policy: The inference policy. Its ``model`` attribute is
            used as the vLLM engine that receives weight updates.
        devices: The devices to use for the training model. The first entry
            selects the training device; ``cuda:0`` is used when empty/None.
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(cfg, devices=devices)
    train_device = torch.device(f"cuda:{devices[0]}" if devices else "cuda:0")

    # Setup loss function
    loss_fn = GRPOLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff
        if (cfg.train.kl_coef_in_loss and cfg.train.use_kl_to_ref)
        else 0.0,
        kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
        entropy_coeff=cfg.train.entropy_coeff,
        masking_strategy="rlhf" if cfg.env.reasoning else "sft",
        device=train_device,
    )
    if cfg.env.reasoning:
        # TODO: this is clunky, we should find a way to do this more naturally
        loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    vllm_engine = inference_policy.model

    # Create weight sync scheme for the collectors
    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)

    # The collectors were created without the scheme, so the sender side is
    # set up manually here. In production, collectors should be created with
    # the ``weight_sync_schemes`` parameter instead.
    torchrl_logger.info("Setting up weight synchronization scheme...")
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # Push the initial training weights before the collectors are started.
    with timeit("update_policy_weights"):
        sender.update_weights()
    torchrl_logger.info("Completed first update_policy_weights. Starting collectors...")
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    for i, collector in enumerate(collectors):
        torchrl_logger.info(f"Starting collector {i}...")
        collector.start()

    # Everything after the collectors are started is wrapped so that they are
    # always shut down, even if training raises (or is interrupted).
    try:
        while not replay_buffer.write_count:
            torchrl_logger.info("Waiting for replay buffer...")
            time.sleep(1)

        # Make optimizer
        optimizer = torch.optim.Adam(
            policy_training.parameters(),
            lr=cfg.optimizer.lr,
            weight_decay=cfg.optimizer.weight_decay,
            fused=False,
        )
        # Gradient scaling is only applied for float16 mixed precision. The
        # scaler is created ONCE so its dynamic loss scale (growth, and
        # backoff after inf/nan gradients) persists across steps. When
        # disabled, scale() returns the loss unchanged, unscale_()/update()
        # are no-ops, and step() simply calls optimizer.step().
        use_grad_scaler = bool(
            cfg.train.mixed_precision and cfg.train_model.torch_dtype == "float16"
        )
        scaler = GradScaler(enabled=use_grad_scaler)

        # Make checkpoint dir
        checkpoint_dir = Path(cfg.logging.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Make wandb logger
        experiment_name = (
            [cfg.logging.experiment_name]
            if cfg.logging.experiment_name is not None
            else []
        )
        experiment_name.append(cfg.env.dataset)
        experiment_name.append(cfg.model.name)
        wandb_logger = WandbLogger(
            project="grpo-async", exp_name="-".join(["grpo-async"] + experiment_name)
        )

        # Training loop
        # ceil(total_dialog_turns / optim_batch_size) * epochs
        # NOTE(review): each iteration samples
        # optim_batch_size // gradient_accumulation_steps items, so in total
        # the loop reads about total_dialog_turns * epochs /
        # gradient_accumulation_steps samples -- confirm this is intended.
        total_steps = (
            -(cfg.train.total_dialog_turns // -cfg.train.optim_batch_size)
            * cfg.train.epochs
        )
        torchrl_logger.info(f"Total steps: {total_steps}")

        # The replay buffer is configured with
        # batch_size = optim_batch_size // gradient_accumulation_steps, so that
        # amount (not the full optim_batch_size) is sampled per step.
        batch_size_per_step = (
            cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
        )

        grad_norm = 0.0  # Reported before the first optimizer step happens
        data_read_count = 0
        start_time = time.time()

        with tqdm.tqdm(total=total_steps) as pbar:
            for step in range(total_steps):
                if not any(c.is_running() for c in collectors):
                    torchrl_logger.info("Collectors stopped, stopping training")
                    break
                pbar.update(1)
                pbar.set_description(
                    f"Step {step}, writes: {replay_buffer.write_count}"
                )

                with timeit("sampling"):
                    batch = replay_buffer.sample(batch_size_per_step).to(
                        train_device
                    )
                    # Render the first sample's history as text for logging.
                    history: History = batch.view(-1)[0]["history", "full"]
                    history_str: list[str] | str = history.apply_chat_template(
                        tokenizer=train_tokenizer
                    )
                    while not isinstance(history_str, str):
                        history_str = "\n".join(history_str)

                data_read_count += batch.numel()

                with timeit("forward_pass"):
                    with autocast("cuda", enabled=cfg.train.mixed_precision):
                        loss = loss_fn(batch)
                        # Divide so accumulated gradients average over micro-steps.
                        loss_val = (
                            loss.mean(reduce=True)
                            / cfg.train.gradient_accumulation_steps
                        )

                with timeit("backward_pass"):
                    scaler.scale(loss_val).backward()

                if (step + 1) % cfg.train.gradient_accumulation_steps == 0:
                    with timeit("optim_step"):
                        # Unscale first so clipping sees true gradient magnitudes.
                        scaler.unscale_(optimizer)

                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            policy_training.parameters(),
                            cfg.optimizer.clip_grad_norm,
                        )

                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad(set_to_none=True)

                if (step % cfg.train.logging_frequency) == 0:
                    log_training_metrics(
                        wandb_logger=wandb_logger,
                        replay_buffer=replay_buffer,
                        batch=batch,
                        loss=loss,
                        grad_norm=grad_norm,
                        global_step=step,
                        data_read_count=data_read_count,
                        collector=collectors[0],
                        start_time=start_time,
                        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                        history_str=history_str,
                        use_kl_to_ref=cfg.train.use_kl_to_ref,
                    )

                if step % cfg.train.weight_update_frequency == 0:
                    with timeit("update_policy_weights"):
                        torchrl_logger.info("Updating policy weights...")
                        sender.update_weights()
                        # TODO: do we need this? Does it interfere with other processes?
                        # torch.cuda.empty_cache()
                        gc.collect()

                # Checkpointing disabled to prevent disk space issues
                # if (step + 1) % cfg.train.checkpoint_frequency == 0:
                #     with timeit("save_checkpoint"):
                #         torchrl_logger.info(
                #             f"Saving checkpoint {(step+1) // cfg.train.checkpoint_frequency}..."
                #         )
                #         checkpoint = {
                #             "step": step,
                #             "model_state_dict": policy_training.model.state_dict(),
                #             "optimizer_state_dict": optimizer.state_dict(),
                #             "scaler_state_dict": scaler.state_dict(),
                #             "config": dict(cfg),
                #         }
                #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{step:04d}.pt")

                if step % cfg.train.weight_update_frequency == 0:
                    timeit.print(prefix="timeit")
                    for key, val in timeit.todict().items():
                        wandb_logger.log_scalar(f"timeit/{key}", val)
                    timeit.reset()

                del loss_val
                # TODO: do we need this? Does it interfere with other processes?
                # torch.cuda.empty_cache()
                gc.collect()
    finally:
        # Shut down every collector (previously only the last one was,
        # through the leaked loop variable of the start loop).
        for collector in collectors:
            collector.shutdown()
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _validate_async_config(cfg):
    """Check every config constraint ``main`` depends on, before any resource is created.

    Only reads attributes of ``cfg``; it starts nothing and mutates nothing.

    Raises:
        ValueError: if ``train.sync`` is set, if a required ``num_devices`` is
            ``None`` (``ref_model`` only when ``train.use_kl_to_ref`` is set),
            if ``train.gradient_accumulation_steps`` is smaller than 1 or does
            not divide ``train.optim_batch_size``, or if ``train.sync_iter`` is set.
    """
    # Force async mode
    if cfg.train.sync:
        raise ValueError(
            "grpo-async.py must run in async mode (`python grpo-async.py mode=async`). Please use grpo-sync.py for sync mode (`python grpo-sync.py mode=sync`)."
        )

    # Check if num_devices is set (before device allocation consumes these values)
    if cfg.inference_model.num_devices is None:
        raise ValueError(
            "Inference model num_devices must be set via inference_model.num_devices"
        )
    if cfg.train.use_kl_to_ref and cfg.ref_model.num_devices is None:
        raise ValueError(
            "Ref model num_devices must be set via ref_model.num_devices when use_kl_to_ref is True"
        )
    if cfg.train_model.num_devices is None:
        raise ValueError(
            "Train model num_devices must be set via train_model.num_devices"
        )

    # A zero step count would otherwise surface as a ZeroDivisionError from the
    # modulo below (and the floor division used for the replay-buffer batch size).
    if cfg.train.gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be at least 1")
    if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
        raise ValueError(
            "optim_batch_size must be divisible by gradient_accumulation_steps"
        )

    if cfg.train.sync_iter is not None:
        raise ValueError("sync_iter is not supported in async mode.")


def _init_ray(cfg, device_config):
    """Initialize Ray for this run unless Ray is already initialized in this process.

    If the ``RAY_CLUSTER_MANAGED_EXTERNALLY`` environment variable is set to a
    non-empty value, connect with ``ray.init(address="auto")``. Otherwise, start
    Ray with ``cfg.ray.init_config``, with these adjustments:

    - keys starting with ``_`` are dropped;
    - ``num_gpus`` is set from ``device_config["ray_num_gpus"]``;
    - the result is passed through ``merge_ray_runtime_env``.
    """
    if ray.is_initialized():
        return

    # Convert OmegaConf to regular dict and filter out unsupported parameters
    ray_init_config = {
        k: dict(v) if isinstance(v, DictConfig) else v
        for k, v in dict(cfg.ray.init_config).items()
        if not k.startswith("_")
    }

    # Add computed GPU configuration and merge with default runtime_env
    ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
    ray_init_config = merge_ray_runtime_env(ray_init_config)
    torchrl_logger.info(f"Ray init config: {ray_init_config=}")
    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
    if ray_managed_externally:
        # NOTE: in this branch the config computed above is only logged, not applied.
        ray.init(address="auto")
    else:
        ray.init(**ray_init_config)


def _start_collectors(cfg, inference_policy, rb, collector_config):
    """Create ``cfg.env.num_envs`` Ray LLM collectors writing into ``rb``.

    Blocks until every collector reports it is initialized. The first collector
    is awaited before the others are created; the rest are awaited together.

    Returns:
        list: the collectors, in creation order.
    """
    collectors = []
    for i in tqdm.trange(cfg.env.num_envs, desc="Starting collectors"):
        collector = RayLLMCollector(
            env=partial(make_env, cfg, single_env=True),
            policy=inference_policy,
            dialog_turns_per_batch=cfg.train.dialog_turns_per_batch,
            total_dialog_turns=cfg.train.total_dialog_turns,
            replay_buffer=rb,
            ray_init_config=None,
            weight_updater=None,
            track_policy_version=True,
            remote_config=collector_config,
            yield_only_last_steps=cfg.env.reasoning,
            verbose=False,
        )
        collectors.append(collector)
        if i == 0:
            # wait for the first collector to initialize
            ray.get(collector._collector.is_initialized.remote())

    # Fire all remaining readiness probes first, then wait on them in one call.
    inits = [
        collector._collector.is_initialized.remote()
        for collector in tqdm.tqdm(
            collectors[1:], desc="Checking collector initialization"
        )
    ]
    ray.get(inits)
    torchrl_logger.info("All collectors initialized")
    return collectors


@hydra.main(version_base=None, config_path="config", config_name="grpo_gsm8k")
def main(cfg):
    """Launch asynchronous GRPO training on Ray.

    Steps:

    1. Validate the config.
    2. Initialize Ray.
    3. Build the inference policy and the Ray replay buffer (with ``MCAdvantage``
       and optional KL transforms).
    4. Start ``cfg.env.num_envs`` collectors.
    5. Run ``train`` as a Ray remote task on ``cfg.train_model.num_devices`` GPUs,
       blocking until it returns.

    Args:
        cfg: Hydra/OmegaConf config composed by ``@hydra.main``.

    Raises:
        ValueError: if the config is invalid for async mode (see
            ``_validate_async_config``). Raised before Ray or any model is started.
    """
    # Check for required GRPO dependencies
    check_grpo_dependencies()

    # Validate the whole config up front so a misconfigured run fails before
    # device allocation, Ray, the inference model, the replay buffer or the
    # collectors are started.
    _validate_async_config(cfg)

    # Compute device allocation
    device_config = compute_device_allocation(cfg)
    _init_ray(cfg, device_config)

    # Convert OmegaConf to regular dict for Ray configs
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_config = dict(cfg.ray.train_handler_config)

    inference_policy = get_inference_model(
        cfg,
        devices=device_config["inference_model_devices"],
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    rb = RayReplayBuffer(
        storage=partial(
            LazyStackStorage,
            # Fall back to repeats * num_envs when buffer_size is unset/falsy.
            cfg.train.buffer_size or cfg.env.repeats * cfg.env.num_envs,
        ),
        transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
        # Sample optim_batch_size / gradient_accumulation_steps items per draw
        # (divisibility and a positive step count were validated above).
        batch_size=max(
            1, cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
        ),
        remote_config=replay_buffer_config,
    )

    add_kl_transforms_to_replay_buffer(rb, cfg)

    torchrl_logger.info(f"Replay buffer: {rb}")

    # Collectors are given no GPUs and 2 CPUs, overriding cfg.ray.collector_config.
    collector_config["num_gpus"] = 0
    collector_config["num_cpus"] = 2
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    collectors = _start_collectors(cfg, inference_policy, rb, collector_config)

    # Only num_cpus is taken from the user config; num_gpus always follows
    # train_model.num_devices.
    train_handler_config = {
        "num_cpus": train_handler_config.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # launch training (blocks until the remote ``train`` task returns)
    ray.get(
        train_handler.remote(
            rb,
            cfg,
            collectors,
            inference_policy,
            devices=device_config["train_model_devices"],
        )
    )
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
if __name__ == "__main__":
    # Setup environment (runs before the Hydra entry point is invoked)
    setup_environment()
    # ``main`` is wrapped by ``@hydra.main``, so it is called without arguments:
    # Hydra composes ``cfg`` from ``config/grpo_gsm8k`` plus any command-line
    # overrides and passes it in.
    main()
|