torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,843 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
import warnings
|
|
11
|
+
from typing import Any, Literal
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from omegaconf import DictConfig
|
|
15
|
+
from torch import device as torch_device, dtype as torch_dtype
|
|
16
|
+
|
|
17
|
+
from torchrl._utils import logger as torchrl_logger, timeit
|
|
18
|
+
from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
|
|
19
|
+
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
|
|
20
|
+
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
|
|
21
|
+
from torchrl.weight_update.llm import VLLMWeightSyncScheme
|
|
22
|
+
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
|
23
|
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def check_grpo_dependencies() -> None:
    """Verify that the packages used for GRPO training can be imported.

    Every required package and every optional package is probed with
    ``__import__`` (all probes run before anything is reported).

    * If at least one required package cannot be imported, an ``ImportError``
      listing each missing package with its install command is raised.
    * Otherwise, if optional packages are missing, a ``UserWarning`` listing
      them is emitted.
    * When no required package is missing, a confirmation message is logged.

    Raises:
        ImportError: If one or more required packages cannot be imported.
    """
    # Package name -> command shown to the user when the import fails.
    required_hints = {
        "datasets": "pip install datasets",
        "peft": "pip install peft",
        "wandb": "pip install wandb",
        "vllm": "pip install vllm",
        "transformers": "pip install transformers",
        "accelerate": "pip install accelerate",
        "ray": "pip install ray",
        "tqdm": "pip install tqdm",
    }
    optional_hints = {
        "flash_attn": "pip install flash-attn",
        "bitsandbytes": "pip install bitsandbytes",
        "xformers": "pip install xformers",
    }

    def _unimportable(hints):
        # Return the (name, install command) pairs whose import fails.
        absent = []
        for name, hint in hints.items():
            try:
                __import__(name)
            except ImportError:
                absent.append((name, hint))
        return absent

    absent_required = _unimportable(required_hints)
    absent_optional = _unimportable(optional_hints)

    # Required packages are fatal: report all of them in a single error.
    if absent_required:
        listing = "\n".join(f" - {pkg}: {cmd}" for pkg, cmd in absent_required)
        raise ImportError(
            "Missing required packages for GRPO training:\n"
            + listing
            + "\n\nYou can install all GRPO dependencies with:\n"
            + " pip install torchrl[grpo]\n"
            + "or install individual packages as shown above."
        )

    # Optional packages only produce a warning attributed to the caller.
    if absent_optional:
        listing = "\n".join(f" - {pkg}: {cmd}" for pkg, cmd in absent_optional)
        warnings.warn(
            "Missing optional packages that may improve GRPO performance:\n"
            + listing
            + "\n\nThese packages are optional but recommended for optimal performance.",
            UserWarning,
            stacklevel=2,
        )

    torchrl_logger.info("✓ All required GRPO dependencies are available")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_tokenizer(cfg: DictConfig) -> PreTrainedTokenizer:
    """Load and configure the tokenizer for the model named in ``cfg.model.name``.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained``. If its pad
    token is equal to its EOS token, the pad token is replaced by the literal
    string ``"PAD"``. The padding side is always set to ``"left"``.

    Args:
        cfg (DictConfig): Configuration object; only ``cfg.model.name`` is read.

    Returns:
        PreTrainedTokenizer: The configured tokenizer.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(cfg.model.name)

    # Give padding its own token when it would otherwise coincide with EOS.
    pad_is_eos = tok.pad_token == tok.eos_token
    if pad_is_eos:
        tok.pad_token = "PAD"

    tok.padding_side = "left"
    return tok
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_train_model(
    cfg: DictConfig,
    devices: list[int] | None = None,
) -> tuple[TransformersWrapper, PreTrainedTokenizer]:
    """Creates and configures the training model with LoRA adapters.

    The underlying HuggingFace model is built by ``get_hf_model`` using the
    LoRA, quantization, gradient-checkpointing, dtype and attention settings
    found in ``cfg.train_model`` (plus ``cfg.model.name`` and
    ``cfg.model.compile``), then wrapped in a :class:`TransformersWrapper`
    configured for log-probability computation (``generate=False``).

    Args:
        cfg (DictConfig): The hydra configuration object containing model and
            training settings. Reads ``cfg.model``, ``cfg.train_model``,
            ``cfg.env.reasoning`` and ``cfg.train.packing``.
        devices (list[int], optional): CUDA device indices the training model
            may use. Defaults to ``[0]`` when ``None``. With more than one
            device the model is loaded with ``device_map="balanced"``;
            otherwise it is placed on ``cuda:<devices[0]>``.

    Returns:
        tuple[TransformersWrapper, PreTrainedTokenizer]:
            - policy_training: The wrapped training model (left in eval mode)
            - train_tokenizer: The tokenizer for the model

    Raises:
        ValueError: If ``devices`` is an empty list.
    """
    torchrl_logger.info("Creating train model")

    # Resolve the dtype name from the config (e.g. "bfloat16") to a torch dtype.
    model_dtype = getattr(torch, cfg.train_model.torch_dtype)

    # Get configured devices or default to [0]
    train_devices = devices if devices is not None else [0]
    if not train_devices:
        raise ValueError("`devices` must contain at least one device index.")
    primary_device = f"cuda:{train_devices[0]}"

    # Per-device memory budget: a zero budget on every visible GPU that is not
    # a training device, a fixed 24GiB budget on the ones that are.
    max_memory = {
        index: "24GiB" if index in train_devices else "0GiB"
        for index in range(torch.cuda.device_count())
    }
    max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback

    # Let HF handle distribution with max_memory
    device_map = "balanced" if len(train_devices) > 1 else primary_device
    train_model, train_tokenizer = get_hf_model(
        cfg.model.name,
        device_map=device_map,
        max_memory=max_memory,
        lora=cfg.train_model.lora.enabled,
        lora_r=cfg.train_model.lora.r,
        lora_alpha=cfg.train_model.lora.alpha,
        lora_dropout=cfg.train_model.lora.dropout,
        gradient_checkpointing=cfg.train_model.gradient_checkpointing,
        quantize=cfg.train_model.quantization.enabled,
        torch_dtype=model_dtype,
        attn_implementation=cfg.train_model.attn_implementation,
        compile=cfg.model.compile,
    )

    # Force all model parameters to the same dtype
    for param in train_model.parameters():
        param.data = param.data.to(model_dtype)

    policy_training = TransformersWrapper(
        train_model,
        tokenizer=train_tokenizer,
        input_mode="history" if cfg.env.reasoning else "tokens",
        generate=False,
        return_log_probs=True,
        pad_output=False,
        # Use the first training device rather than a hard-coded "cuda:0" so
        # the wrapper's device matches where the model was actually placed
        # when `devices` does not start at 0.
        device=torch.device(primary_device),
        # Enable packing when cfg.train.packing=True by disabling padding
        pad_model_input=not cfg.train.packing,
    )
    # Ensure model stays in eval mode after wrapping
    # (``eval()`` is ``train(False)``, so a single call is sufficient).
    policy_training.model.eval()
    return policy_training, train_tokenizer
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_inference_model(
    cfg: DictConfig,
    devices: list[int] | None = None,
    make_ray_worker: bool = True,
    tokenizer: PreTrainedTokenizer | None = None,
) -> vLLMWrapper:
    """Creates the vLLM-based inference model for fast generation.

    An ``AsyncVLLM`` engine is created from ``cfg.model.name`` with one device
    per replica, and wrapped in a :class:`vLLMWrapper` using ``"history"``
    input mode.

    Side effects: depending on the configuration, the environment variables
    ``VLLM_ATTENTION_BACKEND`` and ``VLLM_ENABLE_FP32_OUTPUT`` are set in
    ``os.environ``.

    Args:
        cfg (DictConfig): The hydra configuration object containing model settings.
            Reads ``cfg.model.name``, ``cfg.env.reasoning`` and the
            ``cfg.inference_model`` section (``num_devices``,
            ``gpu_memory_utilization``, ``enforce_eager``, generation settings
            and a set of optional vLLM parameters).
        devices (list[int], optional): Only consulted when
            ``cfg.inference_model.num_devices`` is ``None``; in that case its
            length (default list: ``[1]``) gives the number of replicas.
            Default: `None`.
        make_ray_worker (bool, optional): Accepted for interface compatibility;
            it is not referenced by this function. Default: `True`.
        tokenizer (PreTrainedTokenizer, optional): The tokenizer to use with the
            inference model. When ``None``, one is loaded from
            ``cfg.model.name`` and configured with left padding.
            Default: `None`.

    Returns:
        vLLMWrapper: The wrapped vLLM model ready for inference.

    Raises:
        AssertionError: If the vLLM server or model initialization fails
    """
    # Imported unconditionally at the top of the function: both the attention
    # backend branch and the FP32 output branch below use ``os.environ``, and
    # a function-local import placed inside only one of them leaves ``os``
    # unbound in the other.
    import os

    from torchrl.modules.llm.backends.vllm import AsyncVLLM

    num_devices = cfg.inference_model.num_devices
    if num_devices is None:
        vllm_devices = devices if devices is not None else [1]
        num_devices = len(vllm_devices)
    else:
        vllm_devices = None
    torchrl_logger.info(
        f"Creating AsyncVLLM inference model with num_devices={num_devices}, devices={vllm_devices}"
    )

    model_name = cfg.model.name

    # Use AsyncVLLM for better performance and async processing
    verbose = getattr(cfg.inference_model, "verbose", True)
    # Compilation is disabled by default for GRPO
    compile_model = getattr(cfg.inference_model, "compile", False)

    # Build parameters dict for AsyncVLLM with all config options
    inference_params = {
        "model_name": model_name,
        "num_devices": 1,
        "num_replicas": num_devices,
        "gpu_memory_utilization": cfg.inference_model.gpu_memory_utilization,
        "enforce_eager": cfg.inference_model.enforce_eager,
        "verbose": verbose,
        "compile": compile_model,
    }

    # The attention implementation is not passed as an engine argument here;
    # it is communicated through the VLLM_ATTENTION_BACKEND environment
    # variable. A missing or ``None`` setting leaves the environment untouched.
    attn_impl = getattr(cfg.inference_model, "attn_implementation", None)
    if attn_impl is not None:
        # Map common attention implementations to vLLM backend names
        attn_backend_map = {
            "flash_attention_2": "FLASH_ATTN",
            "flash_attn": "FLASH_ATTN",
            "sdpa": "TORCH_SDPA",
            "torch_sdpa": "TORCH_SDPA",
            "xformers": "XFORMERS",
        }

        # Unknown names are forwarded upper-cased.
        vllm_backend = attn_backend_map.get(attn_impl, attn_impl.upper())
        os.environ["VLLM_ATTENTION_BACKEND"] = vllm_backend

        torchrl_logger.info(
            f"Setting VLLM_ATTENTION_BACKEND={vllm_backend} (from config: {attn_impl})"
        )

    # Handle FP32 output configuration
    if hasattr(cfg.inference_model, "enable_fp32_output"):
        enable_fp32 = cfg.inference_model.enable_fp32_output
        if enable_fp32:
            os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
            torchrl_logger.info(
                "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
                "This will use FP32 for the final output layer if the model supports it."
            )
        # Add to inference params so it gets passed to AsyncVLLM
        inference_params["enable_fp32_output"] = enable_fp32

    # Add other common vLLM parameters from config if present
    optional_vllm_params = [
        "max_model_len",
        "dtype",
        "trust_remote_code",
        "seed",
        "swap_space",
        "cpu_offload_gb",
        "enable_prefix_caching",
        "tensor_parallel_size",
        "pipeline_parallel_size",
    ]

    for param in optional_vllm_params:
        value = getattr(cfg.inference_model, param, None)
        if value is not None:
            inference_params[param] = value

    # Handle torch_dtype specifically (convert string to torch dtype).
    # When set, it takes precedence over a "dtype" entry copied above.
    dtype_cfg = getattr(cfg.inference_model, "torch_dtype", None)
    if dtype_cfg is not None:
        if isinstance(dtype_cfg, str):
            inference_params["dtype"] = getattr(torch, dtype_cfg)
        else:
            inference_params["dtype"] = dtype_cfg

    inference_server = AsyncVLLM.from_pretrained(**inference_params)
    assert inference_server is not None

    if tokenizer is None:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Give padding its own token when it would otherwise coincide with EOS.
        if tokenizer.pad_token == tokenizer.eos_token:
            tokenizer.pad_token = "PAD"
        tokenizer.padding_side = "left"

    policy = vLLMWrapper(
        inference_server,
        input_mode="history",
        chat_template_name="qwen",
        return_log_probs=not cfg.env.reasoning,
        tokenizer=tokenizer,
        pad_output=False,
        generate_kwargs={
            "max_tokens": cfg.inference_model.max_tokens,
            "include_stop_str_in_output": cfg.inference_model.include_stop_str_in_output,
            "temperature": cfg.inference_model.temperature,
            "top_p": cfg.inference_model.top_p,
        },
    )
    assert policy.model is not None
    return policy
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def get_ref_model(
    cfg: DictConfig,
    tokenizer: PreTrainedTokenizer,
    devices: list[int] | None = None,
) -> TransformersWrapper:
    """Build the frozen reference policy used for the KL penalty.

    The HuggingFace model named by ``cfg.model.name`` is loaded without LoRA
    and with ``requires_grad=False``, put in eval mode, has its parameters
    swapped for their ``.data``, and is wrapped in a non-generating
    :class:`TransformersWrapper` that returns log-probabilities.

    Args:
        cfg (DictConfig): The hydra configuration. Reads ``cfg.model.name``,
            ``cfg.env.reasoning`` and, from ``cfg.ref_model``, ``num_devices``,
            ``torch_dtype``, ``quantization.enabled``,
            ``gradient_checkpointing`` and ``attn_implementation``.
        tokenizer (PreTrainedTokenizer): Tokenizer handed to the wrapper.
        devices (list[int] | None, optional): CUDA indices to use when
            ``cfg.ref_model.num_devices`` is ``None``. If both are ``None``,
            device ``2`` is used. Default: None

    Returns:
        TransformersWrapper: The wrapped reference model in eval mode.
    """
    from tensordict import TensorDict

    torchrl_logger.info("Creating ref model")

    # Resolve the CUDA indices: an explicit device count in the config wins,
    # then the ``devices`` argument, then the historical default of [2].
    configured_count = cfg.ref_model.num_devices
    if configured_count is not None:
        ref_devices = list(range(configured_count))
    elif devices is not None:
        ref_devices = devices
    else:
        ref_devices = [2]

    # Per-device memory budget: selected GPUs get 24GiB, every other visible
    # GPU gets 0GiB so nothing is placed there; CPU is allowed as a fallback.
    max_memory = {
        gpu_index: "24GiB" if gpu_index in ref_devices else "0GiB"
        for gpu_index in range(torch.cuda.device_count())
    }
    max_memory["cpu"] = "24GiB"

    # With several devices HF balances the layers; with one device the model
    # is pinned to it.
    if len(ref_devices) > 1:
        device_map = "balanced"
    else:
        device_map = f"cuda:{ref_devices[0]}"
    model_name = cfg.model.name

    loaded = get_hf_model(
        model_name,
        device_map=device_map,
        max_memory=max_memory,
        torch_dtype=getattr(torch, cfg.ref_model.torch_dtype),
        quantize=cfg.ref_model.quantization.enabled,
        gradient_checkpointing=cfg.ref_model.gradient_checkpointing,
        attn_implementation=cfg.ref_model.attn_implementation,
        lora=False,  # the reference model is never adapted
        requires_grad=False,
    )
    # get_hf_model returns (model, tokenizer); only the model is needed here.
    hf_model = loaded[0].eval()

    # Detach the weights by writing each parameter's ``.data`` back in place.
    TensorDict.from_module(hf_model).data.to_module(hf_model)

    return TransformersWrapper(
        hf_model,
        input_mode="history" if cfg.env.reasoning else "tokens",
        tokenizer=tokenizer,
        generate=False,
        return_log_probs=True,
        pad_output=False,
        device=torch.device("cuda:0"),
    )
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def get_hf_model(
    model_name: str,
    torch_dtype: torch_dtype = torch.float32,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    quantize: bool = False,
    fsdp: str = "",
    fsdp_config: Any = None,
    gradient_checkpointing: bool = True,
    device_map: str
    | dict[str, int | str | torch_device]
    | int
    | torch_device
    | None = None,
    lora: bool = True,
    attn_implementation: Literal["flash_attention_2", "flex_attention", "sdpa"]
    | None = "flex_attention",
    requires_grad: bool = True,
    compile: bool = False,
    max_memory: dict[str, str] | None = None,
) -> tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
    """Creates and configures a HuggingFace model with optional optimizations.

    The process-wide default torch dtype is switched to ``torch_dtype`` while
    the model is built and is always restored before returning or raising.

    Args:
        model_name (str): HuggingFace model identifier (e.g., "Qwen/Qwen2.5-3B")
        torch_dtype (torch.dtype, optional): Model precision. Default: torch.float32
        lora_r (int, optional): LoRA rank. Default: 8
        lora_alpha (int, optional): LoRA alpha scaling. Default: 16
        lora_dropout (float, optional): Accepted for interface compatibility but
            currently ignored: the LoRA config always uses a dropout of 0.0.
            Default: 0.1
        quantize (bool, optional): Whether to enable 4-bit (nf4, double-quant)
            bitsandbytes quantization. Default: False
        fsdp (str, optional): FSDP configuration string. Default: ""
        fsdp_config (Any, optional): Additional FSDP configuration. When both
            ``fsdp`` and ``fsdp_config`` are set, ``bnb_4bit_quant_storage`` is
            set to ``torch_dtype``. Default: None
        gradient_checkpointing (bool, optional): Whether to use gradient
            checkpointing. Also disables the KV cache (``use_cache``). Default: True
        device_map (str | dict | int | torch.device | None, optional): Device
            placement strategy; ``None`` is replaced by ``"auto"``. Default: None
        lora (bool, optional): Whether to apply LoRA adapters. Default: True
        attn_implementation (Literal["flash_attention_2", "flex_attention", "sdpa"] | None, optional):
            Attention implementation to use. Only forwarded when CUDA is
            available. Default: "flex_attention"
        requires_grad (bool, optional): If True, ``requires_grad_(True)`` is
            called on the whole model. Default: True
        compile (bool, optional): Accepted for interface compatibility; not
            used by this function. Default: False
        max_memory (dict[str, str], optional): Per-device memory budget forwarded
            to ``from_pretrained``; ``None`` is replaced by ``{}``. Default: None

    Returns:
        tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
            - model: The configured HuggingFace model, in eval mode
            - tokenizer: The associated tokenizer (left padding)

    Raises:
        ImportError: If quantization or LoRA is requested and the matching
            optional dependency cannot be imported. The original import error
            is attached as ``__cause__``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if max_memory is None:
        max_memory = {}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use a dedicated pad token when the tokenizer reuses EOS for padding.
    if tokenizer.pad_token == tokenizer.eos_token:
        tokenizer.pad_token = "PAD"
    tokenizer.padding_side = "left"

    # Remember the global default dtype; it is restored in ``finally``.
    original_dtype = torch.get_default_dtype()

    try:
        # Changing the default dtype happens inside the ``try`` so that the
        # ``finally`` clause restores it on every exit path.
        torch.set_default_dtype(torch_dtype)

        model_configs = {
            "torch_dtype": torch_dtype,
            "device_map": device_map if device_map is not None else "auto",
            "max_memory": max_memory,
        }
        if torch.cuda.is_available() and attn_implementation:
            torchrl_logger.info(f"{attn_implementation} init")
            model_configs["attn_implementation"] = attn_implementation

        # Under FSDP the quantized weights are stored in the training dtype.
        if fsdp != "" and fsdp_config is not None:
            torchrl_logger.info("Configurations for FSDP")
            bnb_config_params = {"bnb_4bit_quant_storage": torch_dtype}
        else:
            bnb_config_params = {}

        # Enable Quantization
        if quantize:
            try:
                from transformers.utils.quantization_config import BitsAndBytesConfig
            except ImportError as err:
                raise ImportError(
                    "Please install transformers with bitsandbytes support"
                ) from err

            model_configs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch_dtype,
                **bnb_config_params,
            )

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository, and the cache directory is hard-coded.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_cache=not gradient_checkpointing,
            cache_dir="/tmp/.cache",
            **model_configs,
        )

        if gradient_checkpointing:
            torchrl_logger.info("gradient_checkpointing enabled")
            if fsdp == "" and fsdp_config is None:
                model.gradient_checkpointing_enable()
            else:
                # Any FSDP setting switches to non-reentrant checkpointing.
                model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )

        if lora:
            try:
                from peft import get_peft_model, LoraConfig
            except ImportError as err:
                raise ImportError("Please install peft: pip install peft") from err

            # Create LoRA config with explicit dtype setting
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules="all-linear",
                lora_dropout=0.0,  # Disable dropout for RL training
                bias="none",
                task_type="CAUSAL_LM",
                inference_mode=True,  # Force inference mode for consistent behavior
                init_lora_weights=True,  # This ensures weights are initialized
            )

            # Initialize LoRA model
            model = get_peft_model(
                model,
                lora_config,
                autocast_adapter_dtype=False,  # Prevent automatic casting of adapter layers
            )

            # Force the LoRA parameters (and only those) to the requested dtype.
            for param_name, param in model.named_parameters():
                if "lora_" in param_name:
                    param.data = param.data.to(torch_dtype)

        model.eval()  # Ensure model is in eval mode
        if requires_grad:
            model.requires_grad_(True)

        return model, tokenizer

    finally:
        # Restore original dtype
        torch.set_default_dtype(original_dtype)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def make_weight_sync_scheme(
    vllm_engine,
) -> VLLMWeightSyncScheme:
    """Build a :class:`VLLMWeightSyncScheme` from a vLLM engine's settings.

    The tensor-parallel size, replica count and master address/port are read
    from the engine and forwarded to the scheme, which is created with the
    ``"state_dict"`` strategy.

    Args:
        vllm_engine: A vLLM engine exposing ``get_tp_size()``,
            ``get_master_address()`` and ``get_master_port()``. An optional
            ``num_replicas`` attribute is honoured; it defaults to 1 when
            absent.

    Returns:
        VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
    """
    # Query the engine once, in a fixed order, and reuse the values for both
    # the log line and the scheme.
    gpus_per_replica = vllm_engine.get_tp_size()
    replica_count = getattr(vllm_engine, "num_replicas", 1)
    address = vllm_engine.get_master_address()
    port = vllm_engine.get_master_port()

    message = (
        f"Creating VLLMWeightSyncScheme with tp_size={gpus_per_replica}, "
        f"num_replicas={replica_count}, master_address={address}, "
        f"master_port={port}"
    )
    torchrl_logger.info(message)

    scheme = VLLMWeightSyncScheme(
        master_address=address,
        master_port=port,
        gpus_per_replica=gpus_per_replica,
        num_replicas=replica_count,
        strategy="state_dict",
    )
    return scheme
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
def compute_device_allocation(cfg):
    """Derive device index lists and the Ray GPU count from the config.

    Training and inference indices both start at 0. Reference-model indices,
    when KL-to-ref is enabled, are placed right after the highest index used
    by training/inference.

    Args:
        cfg: The configuration object. Reads ``train_model.num_devices``,
            ``inference_model.num_devices``, ``ref_model.num_devices`` and
            ``train.use_kl_to_ref``.

    Returns:
        dict: Device configuration containing:
            - train_model_devices: list of devices for training
            - inference_model_devices: list of devices for inference
            - ray_num_gpus: number of GPUs to tell Ray about
            - cuda_visible_devices: string for CUDA_VISIBLE_DEVICES
    """
    n_train = cfg.train_model.num_devices
    n_inference = cfg.inference_model.num_devices
    # The reference model only claims devices when the KL penalty is in use.
    n_ref = cfg.ref_model.num_devices if cfg.train.use_kl_to_ref else 0

    # Ray is told about the sum of all three pools.
    ray_num_gpus = n_train + n_inference + n_ref

    train_model_devices = list(range(n_train))
    inference_model_devices = list(range(n_inference))

    # Union of the training and inference indices, in ascending order.
    visible = sorted({*train_model_devices, *inference_model_devices})
    if cfg.train.use_kl_to_ref:
        first_ref = visible[-1] + 1 if visible else 0
        visible += range(first_ref, first_ref + n_ref)

    return {
        "train_model_devices": train_model_devices,
        "inference_model_devices": inference_model_devices,
        "ray_num_gpus": ray_num_gpus,
        "cuda_visible_devices": ",".join(str(index) for index in visible),
    }
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def make_env(cfg: DictConfig, single_env: bool = False):
    """Create the environment.

    Args:
        cfg: The configuration object. Reads ``cfg.env.dataset``
            (``"gsm8k"`` or ``"ifeval"``), ``cfg.env.repeats``,
            ``cfg.env.num_envs``, ``cfg.env.reasoning`` and, when reasoning is
            enabled, ``cfg.env.max_steps``.
        single_env: If True, the environment is built with ``num_envs=1``
            instead of ``cfg.env.num_envs``.

    Returns:
        The configured environment. When ``cfg.env.reasoning`` is set, an
        ``AddThinkingPrompt`` transform is appended to it.

    Raises:
        NotImplementedError: If ``cfg.env.dataset`` is not a supported dataset.
    """
    train_tokenizer = get_tokenizer(cfg)

    # Without reasoning the environment runs for a single step.
    max_steps = cfg.env.max_steps if cfg.env.reasoning else 1

    # Pick the environment class and the reward threshold used by the
    # thinking-prompt condition below.
    if cfg.env.dataset == "gsm8k":
        # Reward scale is 0.0 to 100
        env_cls, reward_threshold = GSM8KEnv, 20
    elif cfg.env.dataset == "ifeval":
        # Reward scale is 0.0 to 2.2
        env_cls, reward_threshold = IFEvalEnv, 1.0
    else:
        raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")

    env = env_cls(
        repeats=cfg.env.repeats,
        tokenizer=train_tokenizer,
        num_envs=cfg.env.num_envs if not single_env else 1,
        max_steps=max_steps,
        device=torch.device("cpu"),
        ray_backend=True,
    )

    if cfg.env.reasoning:
        env = env.append_transform(
            AddThinkingPrompt(
                # Both values are bound as lambda defaults so the condition
                # carries them itself instead of closing over this scope.
                cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: (
                    td["reward"] <= reward_threshold
                    and td["step_count"] < max_steps
                ),
                role="user",
                edit_last_turn=False,
                zero_reward=False,
                undo_done=True,
                random_prompt=True,
            ),
        )
    return env
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def make_ref_model_factory(cfg: DictConfig) -> functools.partial | None:
    """Return a zero-argument builder for the reference model, if needed.

    Args:
        cfg: The configuration object. ``cfg.train.use_kl_to_ref`` decides
            whether a factory is produced.

    Returns:
        A ``functools.partial`` around ``get_ref_model`` bound to a fresh
        ``DictConfig`` built from ``dict(cfg)``, the training tokenizer and
        ``devices=[0]``; ``None`` when KL to ref is disabled.
    """
    if cfg.train.use_kl_to_ref:
        tokenizer = get_tokenizer(cfg)
        # Re-wrap the config in a new DictConfig before binding it.
        factory_cfg = DictConfig(dict(cfg))
        return functools.partial(
            get_ref_model,
            factory_cfg,
            tokenizer,
            devices=[0],
        )
    return None
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
def add_kl_transforms_to_replay_buffer(replay_buffer, cfg: DictConfig):
    """Append the KL-to-reference transform to a replay buffer.

    Does nothing when ``cfg.train.use_kl_to_ref`` is false or when no
    reference-model factory is available. Otherwise a ``RetrieveKL``
    (reasoning mode) or ``KLRewardTransform`` (non-reasoning mode) is appended
    with ``invert=True``.

    Args:
        replay_buffer: The replay buffer to add transforms to
        cfg: The configuration object. Reads ``cfg.train.use_kl_to_ref``,
            ``cfg.train.kl_coef_in_loss``, ``cfg.train.kl_to_ref_coeff`` and
            ``cfg.env.reasoning``.
    """
    if not cfg.train.use_kl_to_ref:
        return

    factory = make_ref_model_factory(cfg)
    if factory is None:
        return

    # Arguments common to both transform flavours. The KL term is added to
    # the reward only when it is not already handled in the loss.
    common_kwargs = {
        "ref_model_factory": factory,
        "add_to_reward": not cfg.train.kl_coef_in_loss,
        "use_ray_service": True,
    }
    kl_coefficient = cfg.train.kl_to_ref_coeff

    if cfg.env.reasoning:
        # RetrieveKL spells its coefficient keyword ``coeff``.
        kl_transform = RetrieveKL(coeff=kl_coefficient, **common_kwargs)
    else:
        # KLRewardTransform spells it ``coef`` and takes an explicit device.
        kl_transform = KLRewardTransform(
            coef=kl_coefficient,
            device=torch.device("cuda:0"),
            **common_kwargs,
        )

    replay_buffer.append_transform(kl_transform, invert=True)
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
@timeit("Logging metrics")
def log_training_metrics(
    wandb_logger,
    replay_buffer,
    batch,
    loss,
    grad_norm,
    global_step,
    data_read_count,
    collector,
    start_time,
    gradient_accumulation_steps,
    history_str=None,
    use_kl_to_ref=True,
):
    """Log training metrics to wandb.

    Every metric is sent with ``wandb_logger.log_scalar`` at step
    ``global_step``; the optional history text is sent with
    ``wandb_logger.log_str``.

    Args:
        wandb_logger: The wandb logger instance
        replay_buffer: The replay buffer containing collected data. Its full
            content (``replay_buffer[:]``) and ``write_count`` are read.
        batch: The current training batch. Only ``("next", "policy_version")``
            is read from it.
        loss: The computed loss object
        grad_norm: The gradient norm value. Reported as 0.0 on steps where
            ``global_step`` is not a multiple of ``gradient_accumulation_steps``.
        global_step: Current global training step
        data_read_count: Total data read count
        collector: The collector instance; its ``policy_version`` is read.
        start_time: Training start time, compared against ``time.time()``.
        gradient_accumulation_steps: Number of gradient accumulation steps
        history_str: Optional history string for logging
        use_kl_to_ref: If True, also log the ``kl_penalty`` stored in the
            buffer and the ``kl_to_ref`` / ``loss_kl_to_ref`` values of the loss.
    """
    with torch.no_grad():
        # Snapshot of everything currently stored in the buffer.
        rb_content = replay_buffer[:]
        step_count = rb_content.get(("next", "step_count")).view(-1).float().mean()
        # Age of the batch: collector's policy version minus the smallest
        # policy version present in the sampled batch.
        batch_policy_version = batch["next", "policy_version"].view(-1).min()
        batch_policy_age = collector.policy_version - batch_policy_version

        metrics = {
            "step_count from buffer": float(step_count),
            # Entries are fetched as a list and concatenated before averaging.
            "reward from buffer": float(
                torch.cat(rb_content.get(("next", "reward"), as_list=True)).mean()
            ),
            # Mean number of elements per stored response.
            "seq_length from buffer": float(
                torch.tensor(
                    [
                        t.numel()
                        for t in rb_content.get(("tokens", "response"), as_list=True)
                    ],
                    dtype=torch.float,
                ).mean()
            ),
            "ESS, from loss": float(loss.ESS),
            "loss_objective, from loss": float(loss.loss_objective),
            "clip_fraction, from loss": float(loss.clip_fraction),
            "kl_approx (train to inference), from loss": float(loss.kl_approx),
            "kl_to_inference (train to inference - differentiable), from loss": float(
                loss.kl_to_inference.mean()
            ),
            "loss_kl_to_inference, from loss": float(loss.loss_kl_to_inference.mean()),
            "entropy loss, from loss": float(loss.loss_entropy.mean()),
            # Only reported on steps that are a multiple of the accumulation
            # count; 0.0 otherwise.
            "grad_norm": float(grad_norm)
            if global_step % gradient_accumulation_steps == 0
            else 0.0,
            "write_count, from buffer": int(replay_buffer.write_count),
            # how many gradient steps per write
            "gradient_step_throughput (gradient step per write)": float(
                global_step / replay_buffer.write_count
            ),
            # how many optim steps per write
            "optim_step_throughput (optim step per write)": float(
                (global_step // gradient_accumulation_steps) / replay_buffer.write_count
            ),
            "data_read_count (total)": data_read_count,
            "current_policy_version (collector)": collector.policy_version,
            # FIXME: Assume batch is a single trajectory
            # FIXME: The addition of the transform after the env instantiation + _shuttle creation
            # is messed up - we need the next data
            "batch_policy_version (sampled batch)": batch_policy_version,
            "batch_policy_age (sampled batch)": batch_policy_age,
            "throughput (steps per second)": float(
                global_step / (time.time() - start_time)
            ),
        }
        # KL-to-reference metrics are only read when the penalty is enabled.
        if use_kl_to_ref:
            metrics["kl_penalty (inference to ref) from buffer"] = float(
                torch.cat(rb_content.get(("next", "kl_penalty"), as_list=True)).mean()
            )
            metrics["kl_to_ref, from loss"] = float(loss.kl_to_ref.mean())
            metrics["loss_kl_to_ref, from loss"] = float(loss.loss_kl_to_ref.mean())

        for name, value in metrics.items():
            wandb_logger.log_scalar(name, value, step=global_step)

        if history_str is not None:
            wandb_logger.log_str("history", history_str, step=global_step)
|