torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,770 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any, Literal
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
|
|
13
|
+
from tensordict import TensorDict
|
|
14
|
+
from torch import device as torch_device, dtype as torch_dtype
|
|
15
|
+
|
|
16
|
+
from torchrl._utils import logger as torchrl_logger
|
|
17
|
+
from torchrl.envs.llm import RetrieveLogProb
|
|
18
|
+
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
|
|
19
|
+
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
|
|
20
|
+
from torchrl.weight_update.llm import VLLMWeightSyncScheme
|
|
21
|
+
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
|
22
|
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
import ray
|
|
26
|
+
except ImportError:
|
|
27
|
+
ray = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_tokenizer(cfg: DictConfig) -> PreTrainedTokenizer:
    """Load and configure the HuggingFace tokenizer for ``cfg.model.name``.

    When the tokenizer's padding token is identical to its EOS token, the
    padding token is replaced by the literal string ``"PAD"`` so that padding
    and end-of-sequence are not the same token. The padding side is always
    set to ``"left"``.

    Args:
        cfg (DictConfig): Configuration object; only ``cfg.model.name`` is read.

    Returns:
        PreTrainedTokenizer: The configured tokenizer.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(cfg.model.name)
    # Keep padding distinguishable from end-of-sequence.
    if tok.pad_token == tok.eos_token:
        tok.pad_token = "PAD"
    tok.padding_side = "left"
    return tok
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def make_env(cfg: DictConfig, devices: list[int] | None = None):
    """Create the dataset environment with a reference-model log-prob transform.

    Builds a GSM8K environment (``cfg.env.dataset == "gsm8k"``) or, for any
    other dataset value, an IFEval environment, both on CPU. It then appends a
    :class:`RetrieveLogProb` transform that scores assistant tokens with a
    reference model and writes the result under ``("ref_log_probs", "full")``.

    Args:
        cfg (DictConfig): The configuration object. Reads ``cfg.env.dataset``,
            ``cfg.env.repeats`` and ``cfg.env.num_envs``, plus whatever
            :func:`get_tokenizer` and :func:`get_ref_model` read.
        devices (list[int] | None, optional): The devices to use for the
            reference model, forwarded to :func:`get_ref_model`.
            Defaults to ``None``.

    Returns:
        The configured environment with the ``RetrieveLogProb`` transform
        appended.
    """
    train_tokenizer = get_tokenizer(cfg)

    # The reference model is built from a new DictConfig constructed from the
    # top-level entries of ``cfg``; no entries are modified here.
    # NOTE: the device counts (inference_model / ref_model ``num_devices``)
    # are intentionally not combined here. They may legitimately be ``None``,
    # and adding them used to raise ``TypeError`` for such configs.
    ref_cfg = DictConfig(dict(cfg))
    ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)

    # Both dataset environments share the same constructor arguments.
    env_kwargs = dict(
        repeats=cfg.env.repeats,
        tokenizer=train_tokenizer,
        num_envs=cfg.env.num_envs,
        device=torch.device("cpu"),
    )
    if cfg.env.dataset == "gsm8k":
        from torchrl.envs.llm import GSM8KEnv

        env = GSM8KEnv(**env_kwargs)
    else:  # any other dataset value falls back to ifeval
        env = IFEvalEnv(**env_kwargs)

    # Pass the device directly to RetrieveLogProb. For Ray, the local device
    # index is always 0, so cuda:0 is used.
    device = torch.device("cuda:0")
    env = env.append_transform(
        RetrieveLogProb(
            model=ref_model,
            assistant_only=True,
            tokenizer_kwargs={"chat_template_name": "qwen"},
            device=device,
            log_probs_full_key=("ref_log_probs", "full"),
        )
    )
    return env
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_train_model(
    cfg: DictConfig,
    devices: list[int] | None = None,
    chat_template_name: str | None = None,
) -> tuple[TransformersWrapper, PreTrainedTokenizer]:
    """Creates and configures the training model with LoRA adapters.

    This function initializes the main training model with LoRA adapters and
    other training-specific configurations like gradient checkpointing. The
    model is wrapped in a TransformersWrapper for policy training.

    Args:
        cfg (DictConfig): The hydra configuration object containing model and
            training settings. Expected to have a ``train_model`` section with
            LoRA, quantization, and other training-specific parameters.
        devices (list[int] | None, optional): The GPU indices to use for the
            training model. Defaults to ``None``, which means ``[0]``.
        chat_template_name (str | None, optional): The name of the chat
            template to use. It is forwarded to ``TransformersWrapper``.
            Defaults to ``None``.

    Returns:
        tuple[TransformersWrapper, PreTrainedTokenizer]:
            - policy_training: The wrapped training model
            - train_tokenizer: The tokenizer for the model

    Raises:
        ValueError: If ``devices`` is an empty list.
    """
    torchrl_logger.info("Creating train model")

    # Set model dtype explicitly, e.g. "bfloat16" -> torch.bfloat16.
    model_dtype = getattr(torch, cfg.train_model.torch_dtype)

    # Get configured devices or default to [0].
    train_devices = devices if devices is not None else [0]
    if not train_devices:
        raise ValueError(
            "get_train_model requires at least one device, got an empty list."
        )

    # Build a max_memory dict: the memory budget goes to the selected GPUs,
    # 0 to every other visible GPU, and CPU memory is allowed as a fallback.
    max_memory: dict[int | str, str] = {
        i: ("24GiB" if i in train_devices else "0GiB")
        for i in range(torch.cuda.device_count())
    }
    max_memory["cpu"] = "24GiB"

    # Let HF handle the distribution across the allowed devices.
    device_map = "balanced" if len(train_devices) > 1 else f"cuda:{train_devices[0]}"
    train_model, train_tokenizer = get_hf_model(
        cfg.model.name,
        device_map=device_map,
        max_memory=max_memory,
        lora=cfg.train_model.lora.enabled,
        lora_r=cfg.train_model.lora.r,
        lora_alpha=cfg.train_model.lora.alpha,
        lora_dropout=cfg.train_model.lora.dropout,
        gradient_checkpointing=cfg.train_model.gradient_checkpointing,
        quantize=cfg.train_model.quantization.enabled,
        torch_dtype=model_dtype,
        attn_implementation=cfg.train_model.attn_implementation,
        compile=cfg.model.compile,
        eval_mode=cfg.train_model.eval,
    )

    # Force all floating-point parameters to the same dtype. Non-floating
    # tensors (e.g. packed integer storage of quantized weights) are left
    # untouched: value-casting them to a float dtype would corrupt them.
    for param in train_model.parameters():
        if param.is_floating_point():
            param.data = param.data.to(model_dtype)

    # NOTE(review): the wrapper device is hard-coded to cuda:0 even when
    # ``train_devices`` points elsewhere (presumably a Ray-local index);
    # confirm for non-Ray use.
    policy_training = TransformersWrapper(
        train_model,
        tokenizer=train_tokenizer,
        input_mode="history",
        generate=False,
        return_log_probs=True,
        pad_output=False,
        device=torch.device("cuda:0"),
        chat_template_name=chat_template_name,
    )
    # Ensure the model is in eval mode after wrapping. This is applied
    # regardless of ``cfg.train_model.eval``.
    policy_training.model.eval()
    return policy_training, train_tokenizer
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def get_inference_model(
    cfg: DictConfig,
    devices: list[int] | None = None,
    make_ray_worker: bool = True,
    tokenizer: PreTrainedTokenizer | None = None,
) -> vLLMWrapper:
    """Build a vLLM-backed inference policy for fast generation.

    A vLLM worker is created through ``make_vllm_worker`` and wrapped in a
    ``vLLMWrapper`` that operates on chat histories with the ``"qwen"`` chat
    template and returns log-probabilities.

    Args:
        cfg (DictConfig): The hydra configuration object. Reads
            ``cfg.model.name`` and ``cfg.inference_model.{num_devices,
            gpu_memory_utilization, enforce_eager, max_tokens,
            include_stop_str_in_output, temperature}``.
        devices (list[int], optional): The devices for the inference model.
            They are only used when ``cfg.inference_model.num_devices`` is
            ``None``; in that case they default to ``[1]``. Default: ``None``.
        make_ray_worker (bool, optional): Whether to make a ray worker.
            Default: ``True``.
        tokenizer (PreTrainedTokenizer | None, optional): The tokenizer to use.
            If ``None``, one is built with :func:`get_tokenizer`.
            Default: ``None``.

    Returns:
        vLLMWrapper: The wrapped vLLM model ready for inference.

    Raises:
        AssertionError: If the vLLM server or the wrapped model is ``None``.
    """
    from torchrl.modules.llm.backends import make_vllm_worker

    num_devices = cfg.inference_model.num_devices
    # An explicit device count wins; explicit device ids are dropped then.
    if num_devices is not None:
        vllm_devices = None
    elif devices is None:
        vllm_devices = [1]
    else:
        vllm_devices = devices
    torchrl_logger.info(
        f"Creating inference model with num_devices={num_devices}, devices={vllm_devices}"
    )

    model_name = cfg.model.name
    tokenizer = get_tokenizer(cfg) if tokenizer is None else tokenizer

    # vLLM handles device mapping internally.
    inf_cfg = cfg.inference_model
    inference_server = make_vllm_worker(
        model_name=model_name,
        gpu_memory_utilization=inf_cfg.gpu_memory_utilization,
        num_devices=num_devices,
        # Copy into a plain list for type compatibility.
        devices=None if vllm_devices is None else list(vllm_devices),
        make_ray_worker=make_ray_worker,
        enforce_eager=inf_cfg.enforce_eager,
    )
    assert inference_server is not None

    sampling_kwargs = {
        "max_tokens": inf_cfg.max_tokens,
        "include_stop_str_in_output": inf_cfg.include_stop_str_in_output,
        "temperature": inf_cfg.temperature,
    }
    policy = vLLMWrapper(
        inference_server,
        input_mode="history",
        chat_template_name="qwen",
        return_log_probs=True,
        tokenizer=tokenizer,
        pad_output=False,
        generate_kwargs=sampling_kwargs,
    )
    assert policy.model is not None
    return policy
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def get_ref_model(
    cfg: DictConfig, tokenizer: PreTrainedTokenizer, devices: list[int] | None = None
) -> TransformersWrapper:
    """Creates the reference model for KL penalty computation.

    This function initializes a copy of the base model without LoRA, with
    ``requires_grad=False`` and in eval mode, to serve as the reference model
    for KL divergence computation. Its weights are detached before wrapping.

    Args:
        cfg (DictConfig): The hydra configuration object containing model
            settings. Expected to have a ``ref_model`` section with
            ``num_devices``, ``torch_dtype``, quantization, gradient
            checkpointing and attention settings.
        tokenizer (PreTrainedTokenizer): The tokenizer to use with the
            reference model.
        devices (list[int] | None, optional): GPU indices for the reference
            model. They are only used when ``cfg.ref_model.num_devices`` is
            ``None``; in that case they default to ``[2]``. When
            ``num_devices`` is set, indices ``0 .. num_devices - 1`` are used
            instead. Defaults to ``None``.

    Returns:
        TransformersWrapper: The wrapped reference model in eval mode with
        detached weights.

    Raises:
        ValueError: If the resolved device list is empty (e.g. ``devices=[]``
            or ``cfg.ref_model.num_devices == 0``).
    """
    torchrl_logger.info("Creating ref model")

    # Configured device count takes precedence; otherwise use devices or [2].
    if cfg.ref_model.num_devices is None:
        ref_devices = devices if devices is not None else [2]
    else:
        ref_devices = list(range(cfg.ref_model.num_devices))
    if not ref_devices:
        raise ValueError(
            "get_ref_model requires at least one device, got an empty device list."
        )

    # Build a max_memory dict: the memory budget goes to the selected GPUs,
    # 0 to every other visible GPU, and CPU memory is allowed as a fallback.
    max_memory: dict[int | str, str] = {
        i: ("24GiB" if i in ref_devices else "0GiB")
        for i in range(torch.cuda.device_count())
    }
    max_memory["cpu"] = "24GiB"

    # Let HF handle the distribution across the allowed devices.
    device_map = "balanced" if len(ref_devices) > 1 else f"cuda:{ref_devices[0]}"

    # get_hf_model returns (model, tokenizer); only the model is needed here.
    hf_model = get_hf_model(
        cfg.model.name,
        device_map=device_map,
        max_memory=max_memory,
        torch_dtype=getattr(torch, cfg.ref_model.torch_dtype),
        quantize=cfg.ref_model.quantization.enabled,
        gradient_checkpointing=cfg.ref_model.gradient_checkpointing,
        attn_implementation=cfg.ref_model.attn_implementation,
        lora=False,  # Reference model doesn't need LoRA
        requires_grad=False,
        eval_mode=True,
        lora_dropout=0.0,
    )[0]
    # Detach weights by writing the `.data` view of the module's tensors back
    # into the module.
    TensorDict.from_module(hf_model).data.to_module(hf_model)
    return TransformersWrapper(
        hf_model,
        tokenizer=tokenizer,
        input_mode="history",
        generate=False,
        return_log_probs=True,
        pad_output=False,
        device=torch.device("cuda:0"),
        chat_template_name="qwen",
    )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def get_hf_model(
    model_name: str,
    torch_dtype: torch_dtype = torch.float32,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    quantize: bool = False,
    fsdp: str = "",
    fsdp_config: Any = None,
    gradient_checkpointing: bool = True,
    device_map: str
    | dict[str, int | str | torch_device]
    | int
    | torch_device
    | None = None,
    lora: bool = True,
    attn_implementation: Literal["flash_attention_2", "flex_attention", "sdpa"]
    | None = "flex_attention",
    requires_grad: bool = True,
    compile: bool = False,
    max_memory: dict[str, str] | None = None,
    eval_mode: bool = False,
) -> tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
    """Creates and configures a HuggingFace model with optional optimizations.

    The process-wide default dtype is temporarily switched to ``torch_dtype``
    while the model is loaded, and is always restored before this function
    returns or raises.

    Args:
        model_name (str): HuggingFace model identifier (e.g., "Qwen/Qwen2.5-3B").
        torch_dtype (torch.dtype, optional): Model precision. Default: torch.float32
        lora_r (int, optional): LoRA rank - controls capacity of adaptations. Default: 8
        lora_alpha (int, optional): LoRA alpha - scales the adaptations. Default: 16
        lora_dropout (float, optional): Dropout probability for LoRA layers. Default: 0.1
        quantize (bool, optional): Whether to enable 4-bit (NF4, double-quantized)
            bitsandbytes quantization. Default: False
        fsdp (str, optional): Fully Sharded Data Parallel configuration. When it is
            non-empty or ``fsdp_config`` is provided, gradient checkpointing (if
            enabled) uses non-reentrant checkpointing. Default: ""
        fsdp_config (Any, optional): Additional FSDP configurations. Default: None
        gradient_checkpointing (bool, optional): Whether to use gradient
            checkpointing. When True the model is also loaded with
            ``use_cache=False``. Default: True
        device_map (str | dict | int | torch.device | None, optional): Device
            placement strategy forwarded to ``from_pretrained``; ``None`` means
            ``"auto"``. Default: None
        lora (bool, optional): Whether to apply LoRA adapters to all linear
            layers. Default: True
        attn_implementation (Literal["flash_attention_2", "flex_attention", "sdpa"] | None, optional):
            Attention implementation to use. Only forwarded when CUDA is
            available. Default: "flex_attention"
        requires_grad (bool, optional): Value passed to ``model.requires_grad_``
            for every parameter of the returned model. Default: True
        compile (bool, optional): Currently not used by this function; accepted
            for API compatibility. Default: False
        max_memory (dict[str, str], optional): Per-device memory limits forwarded
            to ``from_pretrained``. ``None`` is treated as ``{}``. Default: None
        eval_mode (bool, optional): If True, the model is put in eval mode and
            LoRA adapters are configured for inference; otherwise the model is
            put in train mode with trainable-mode adapters. Default: False

    Returns:
        tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
            - model: The configured HuggingFace model (wrapped by peft when ``lora``)
            - tokenizer: The associated tokenizer (left padding)

    Raises:
        ImportError: If ``quantize`` is set and bitsandbytes support is missing,
            or if ``lora`` is set and peft is not installed. Errors raised by
            transformers/peft while loading are propagated unchanged.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if max_memory is None:
        max_memory = {}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # tokenizer.eos_token = "<|im_end|>"
    # Padding with the EOS token would make padding indistinguishable from
    # end-of-sequence, so a separate pad token is used instead.
    if tokenizer.pad_token == tokenizer.eos_token:
        tokenizer.pad_token = "PAD"
    tokenizer.padding_side = "left"

    model_configs = {
        "torch_dtype": torch_dtype,
        "device_map": device_map if device_map is not None else "auto",
        "max_memory": max_memory,
    }
    if torch.cuda.is_available() and attn_implementation:
        torchrl_logger.info(f"{attn_implementation} init")
        model_configs["attn_implementation"] = attn_implementation

    # Configure model settings for mixed precision. The previous default dtype
    # is captured here and restored in ``finally`` no matter what happens.
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch_dtype)
    try:
        if fsdp != "" and fsdp_config is not None:
            torchrl_logger.info("Configurations for FSDP")

        # Enable Quantization
        if quantize:
            try:
                from transformers.utils.quantization_config import BitsAndBytesConfig
            except ImportError as err:
                raise ImportError(
                    "Please install transformers with bitsandbytes support"
                ) from err

            model_configs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch_dtype,
            )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            # The KV cache is disabled when activations are recomputed.
            use_cache=not gradient_checkpointing,
            cache_dir="/tmp/.cache",
            **model_configs,
        )

        if gradient_checkpointing:
            torchrl_logger.info("gradient_checkpointing enabled")
            if fsdp == "" and fsdp_config is None:
                model.gradient_checkpointing_enable()
            else:
                # Any FSDP option switches to non-reentrant checkpointing.
                model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )

        if lora:
            try:
                from peft import get_peft_model, LoraConfig
            except ImportError as err:
                raise ImportError("Please install peft: pip install peft") from err

            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules="all-linear",
                lora_dropout=lora_dropout,  # Standard dropout for regularization
                bias="none",
                task_type="CAUSAL_LM",
                # PEFT's inference_mode must be False for training; it is only
                # enabled when the model is requested in eval mode.
                inference_mode=eval_mode,
                init_lora_weights=True,
            )

            model = get_peft_model(
                model,
                lora_config,
                autocast_adapter_dtype=False,  # Keep adapters in torch_dtype
            )

            # Force LoRA layers to the requested dtype.
            for name, param in model.named_parameters():
                if "lora_" in name:
                    param.data = param.data.to(torch_dtype)

        if eval_mode:
            model.eval()
        else:
            model.train(True)
        # Applied last so it overrides any trainability set by peft above.
        model.requires_grad_(bool(requires_grad))
        return model, tokenizer

    finally:
        # Restore original dtype
        torch.set_default_dtype(original_dtype)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def make_weight_sync_scheme(
    master_address=None,
    master_port=None,
    vllm_tp_size=1,
) -> VLLMWeightSyncScheme:
    """Build the NCCL-backed scheme that pushes trainer weights into vLLM.

    The returned scheme transfers weights GPU-to-GPU with NCCL collectives
    from the training model to the vLLM inference workers, using a single
    replica and the ``"state_dict"`` strategy.

    Args:
        master_address (str | None): Host of the distributed-init master.
            ``None`` is replaced by ``"localhost"``.
        master_port (int | None): Port of the distributed-init master,
            forwarded unchanged (``None`` lets the scheme pick one).
        vllm_tp_size (int): vLLM tensor-parallel size, forwarded as
            ``gpus_per_replica``. Defaults to 1.

    Returns:
        VLLMWeightSyncScheme: The configured weight sync scheme.
    """
    master_address = "localhost" if master_address is None else master_address

    torchrl_logger.info(
        f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, "
        f"master_address={master_address}, master_port={master_port}"
    )

    scheme_kwargs = {
        "master_address": master_address,
        "master_port": master_port,
        "gpus_per_replica": vllm_tp_size,
        # Expert iteration drives a single vLLM replica.
        "num_replicas": 1,
        "strategy": "state_dict",
    }
    return VLLMWeightSyncScheme(**scheme_kwargs)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def compute_device_allocation(cfg):
    """Derive per-component GPU index lists from the configuration.

    Training and inference indices both start at 0. The reference model's
    indices follow directly after the inference indices. Ray is asked for
    the sum of all three device counts, even though the index lists may
    overlap.

    Args:
        cfg: Configuration exposing ``train_model.num_devices``,
            ``inference_model.num_devices`` and ``ref_model.num_devices``.
            All three must be ints here; a ``None`` ref count (which
            ``get_ref_model`` tolerates) raises ``TypeError``.

    Returns:
        dict: ``train_model_devices``, ``inference_model_devices`` and
        ``ref_model_devices`` (lists of ints), ``ray_num_gpus`` (int) and
        ``cuda_visible_devices`` (comma-separated sorted unique indices).
    """
    n_train = cfg.train_model.num_devices
    n_inference = cfg.inference_model.num_devices
    n_ref = cfg.ref_model.num_devices

    # Ray must reserve GPUs for every component.
    ray_num_gpus = n_train + n_inference + n_ref

    train_model_devices = list(range(n_train))
    inference_model_devices = list(range(n_inference))
    ref_model_devices = list(range(n_inference, n_inference + n_ref))

    unique_devices = sorted(
        {*train_model_devices, *inference_model_devices, *ref_model_devices}
    )

    return {
        "train_model_devices": train_model_devices,
        "inference_model_devices": inference_model_devices,
        "ref_model_devices": ref_model_devices,
        "ray_num_gpus": ray_num_gpus,
        "cuda_visible_devices": ",".join(str(dev) for dev in unique_devices),
    }
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def create_cosine_scheduler_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> torch.optim.lr_scheduler.LRScheduler:
    """Create a cosine scheduler with warmup using PyTorch's built-in schedulers.

    This function creates a learning rate scheduler that:
    1. Linearly increases the learning rate from 1% of the base learning rate
       to the base learning rate over ``num_warmup_steps`` steps.
    2. Follows a cosine curve from the base learning rate to 0 over the
       remaining ``num_training_steps - num_warmup_steps`` steps (at least 1).

    When ``num_warmup_steps <= 0`` the warmup phase is skipped entirely and a
    plain cosine decay starting at the base learning rate is returned.

    Args:
        optimizer: The optimizer to schedule learning rates for.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Total number of training steps.
        num_cycles: Accepted for API compatibility but currently ignored; the
            decay is always a single half cosine cycle (what 0.5 describes).

    Returns:
        A PyTorch learning rate scheduler.
    """
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    # CosineAnnealingLR divides by T_max, so the decay phase needs >= 1 step.
    decay_steps = max(1, num_training_steps - max(num_warmup_steps, 0))

    if num_warmup_steps <= 0:
        # With a zero-length warmup, SequentialLR's milestone at step 0 is never
        # hit, so the cosine phase would start from (and remain scaled by)
        # LinearLR's start_factor * base_lr. Skip the warmup scheduler instead.
        return CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=0.0)

    # Warmup: linear increase from 1% of the base LR to the base LR.
    warmup_scheduler = LinearLR(
        optimizer, start_factor=0.01, end_factor=1.0, total_iters=num_warmup_steps
    )

    # Cosine decay from the base LR to 0.
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=0.0)

    # Switch from warmup to cosine decay at the end of the warmup.
    return SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[num_warmup_steps],
    )
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def get_wandb_run_id(wandb_logger):
    """Poll a WandbLogger until its run ID is available.

    Up to 10 checks are made, sleeping 0.5 s between consecutive checks
    (there is no sleep after the final one). Any exception raised while
    polling is logged and turned into a ``None`` result.

    Args:
        wandb_logger: The WandbLogger instance.

    Returns:
        str | None: The wandb run ID, or ``None`` if it could not be obtained.
    """
    attempts = 10
    try:
        for idx in range(attempts):
            experiment_ready = (
                hasattr(wandb_logger, "experiment") and wandb_logger.experiment
            )
            if experiment_ready:
                run_id = wandb_logger.experiment.id
                if run_id:
                    torchrl_logger.info(f"Got wandb run ID: {run_id}")
                    return run_id

            is_last_attempt = idx == attempts - 1
            if not is_last_attempt:
                # Give wandb a moment to finish initializing.
                time.sleep(0.5)
                torchrl_logger.info(
                    f"Waiting for wandb run ID, attempt {idx + 1}/{attempts}"
                )

        torchrl_logger.warning("Could not get wandb run ID after multiple attempts")
        return None
    except Exception as e:
        torchrl_logger.error(f"Error getting wandb run ID: {e}")
        return None
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def log_training_metrics(
    wandb_logger,
    replay_buffer,
    batch,
    loss,
    grad_norm,
    global_step,
    data_read_count,
    collector,
    start_time,
    gradient_accumulation_steps,
    history_str=None,
):
    """Log training metrics to wandb.

    Every metric is converted to a Python scalar before being sent with
    ``wandb_logger.log_scalar`` at ``step=global_step``.

    Args:
        wandb_logger: The wandb logger instance (needs ``log_scalar`` and,
            when ``history_str`` is given, ``log_str``).
        replay_buffer: The replay buffer containing collected data. Its full
            content (``replay_buffer[:]``) is read to compute buffer metrics.
        batch: The current training batch.
        loss: The computed loss object (``loss_sft``, ``loss_kl_to_ref`` and
            ``kl_to_ref`` are read).
        grad_norm: The gradient norm value. Only logged on steps where
            ``global_step`` is a multiple of ``gradient_accumulation_steps``;
            0.0 is logged otherwise.
        global_step: Current global training step.
        data_read_count: Total data read count.
        collector: The collector instance (``policy_version`` is read).
        start_time: Training start time, as returned by ``time.time()``.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        history_str: Optional history string for logging.
    """
    with torch.no_grad():
        rb_content = replay_buffer[:]
        # FIXME: Assume batch is a single trajectory
        # FIXME: The addition of the transform after the env instantiation + _shuttle creation
        # is messed up - we need the next data
        batch_policy_version = batch["next", "policy_version"].view(-1).min()
        batch_policy_age = collector.policy_version - batch_policy_version

        # Read once so both throughput metrics use the same denominator.
        write_count = replay_buffer.write_count
        response_lengths = torch.tensor(
            [
                t.numel()
                for t in rb_content.get(("tokens", "response"), as_list=True)
            ],
            dtype=torch.float,
        )

        metrics = {
            "reward from buffer": float(
                torch.cat(rb_content.get(("next", "reward"), as_list=True)).mean()
            ),
            "reward from batch": float(batch["next", "reward"].mean()),
            "seq_length from buffer": float(response_lengths.mean()),
            "loss_sft, from loss": float(loss.loss_sft),
            "loss_kl_to_ref, from loss": float(loss.loss_kl_to_ref),
            "kl_to_ref, from loss": float(loss.kl_to_ref),
            # The gradient norm is only meaningful on optimizer steps.
            "grad_norm": float(grad_norm)
            if global_step % gradient_accumulation_steps == 0
            else 0.0,
            "write_count, from buffer": int(write_count),
            # how many gradient steps per write
            "gradient_step_throughput (gradient step per write)": float(
                global_step / write_count
            ),
            # how many optim steps per write
            "optim_step_throughput (optim step per write)": float(
                (global_step // gradient_accumulation_steps) / write_count
            ),
            "data_read_count (total)": data_read_count,
            "current_policy_version (collector)": collector.policy_version,
            # Converted to Python floats like the other metrics (these are
            # tensors, possibly on an accelerator, before conversion).
            "batch_policy_version (sampled batch)": float(batch_policy_version),
            "batch_policy_age (sampled batch)": float(batch_policy_age),
            "throughput (steps per second)": float(
                global_step / (time.time() - start_time)
            ),
        }

    for name, value in metrics.items():
        wandb_logger.log_scalar(name, value, step=global_step)

    if history_str is not None:
        wandb_logger.log_str("history", history_str, step=global_step)
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
class RemoteDataLogger:
    """Post-processing callable that ships collector statistics to the main process.

    Each call computes summary statistics of a collected batch and ``put``s them
    as a dict on a queue so that logging can be centralized elsewhere. The
    batch itself is returned unchanged.
    """

    def __init__(self, log_queue):
        """Store the queue that receives the metric dicts.

        Args:
            log_queue: Queue-like object exposing ``put`` (e.g. a Ray queue).
        """
        self.log_queue = log_queue
        # Timestamp of the previous call; None until the first call.
        self.last_time = None

    def __call__(self, data: TensorDict):
        self.log_data(data)
        return data

    def log_data(self, data: TensorDict):
        """Compute reward/response-length/policy-version stats and enqueue them.

        The elapsed wall-clock time since the previous call is included from
        the second call onward. Failures to enqueue are logged, not raised.
        """
        logs = {}

        now = time.time()
        if self.last_time is not None:
            logs["collector/time/elapsed"] = now - self.last_time
        self.last_time = now

        # Reward statistics
        rewards = data["next", "reward"]
        logs["collector/rewards/mean"] = float(rewards.mean())
        logs["collector/rewards/std"] = float(rewards.std())
        logs["collector/rewards/min"] = float(rewards.min())
        logs["collector/rewards/max"] = float(rewards.max())

        # Response length, measured with len() on each response entry
        # (a character count when entries are strings).
        lengths = torch.tensor(
            [len(response) for response in data["text", "response"]],
            dtype=torch.float32,
        )
        logs["collector/response_length/mean"] = float(lengths.mean())
        logs["collector/response_length/std"] = float(lengths.std())
        logs["collector/response_length/min"] = float(lengths.min())
        logs["collector/response_length/max"] = float(lengths.max())

        # Policy version statistics are only reported when available as a tensor.
        versions = data.get(("next", "policy_version"))
        if isinstance(versions, torch.Tensor):
            versions_f = versions.float()
            logs["collector/policy_version/mean"] = float(versions_f.mean())
            logs["collector/policy_version/min"] = float(versions_f.min())
            logs["collector/policy_version/max"] = float(versions_f.max())

        # Hand the metrics to the main process; this is best-effort.
        try:
            self.log_queue.put(logs)
        except Exception as e:
            torchrl_logger.error(f"Failed to send logs to main process: {e}")
|