torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from torchrl._utils import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def register_fp32_overrides() -> None:
    """Register FP32 overrides for vLLM models.

    Registers TorchRL's ``Qwen3ForCausalLMFP32`` implementation with vLLM's
    ``ModelRegistry`` under the ``"Qwen3ForCausalLM"`` architecture name.

    The class is passed as a lazy ``"module:ClassName"`` string, so the module
    path must name the ``_models`` module that sits next to this plugin in the
    ``torchrl.modules.llm.backends.vllm`` package.

    Raises:
        ImportError: if vLLM is not installed.
    """
    # Imported lazily so that importing this module does not require vLLM.
    from vllm.model_executor.models.registry import ModelRegistry

    # ======= Register models here =======
    # Register Qwen3 models with FP32 override.
    # NOTE: ``_models`` lives in the ``backends.vllm`` subpackage; the former
    # path ``torchrl.modules.llm.backends._models`` names a module that does
    # not exist in this package layout, so vLLM could not resolve the class.
    ModelRegistry.register_model(
        "Qwen3ForCausalLM",
        "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32",
    )

    logger.info("Registered Qwen3 FP32 model overrides")
|
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Synchronous vLLM backend for TorchRL.
|
|
7
|
+
|
|
8
|
+
From https://docs.vllm.ai/en/v0.7.0/getting_started/examples/rlhf.html
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from collections.abc import Iterator
|
|
15
|
+
from contextlib import nullcontext
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torchrl._utils import logger as torchrl_logger
|
|
19
|
+
from torchrl.modules.llm.utils import _cuda_visible_devices
|
|
20
|
+
|
|
21
|
+
from .base import RLvLLMEngine
|
|
22
|
+
from .vllm_utils import stateless_init_process_group
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from vllm import LLM
|
|
26
|
+
from vllm.worker.worker import Worker
|
|
27
|
+
|
|
28
|
+
_has_vllm = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
|
|
31
|
+
class LLM:
|
|
32
|
+
"""Placeholder for LLM class when vLLM is not installed."""
|
|
33
|
+
|
|
34
|
+
class Worker:
|
|
35
|
+
"""Placeholder for Worker class when vLLM is not installed."""
|
|
36
|
+
|
|
37
|
+
_has_vllm = False
|
|
38
|
+
|
|
39
|
+
# ``get_open_port`` is not exported by every vLLM release; fall back to a
# standard-library implementation when the import fails.
try:
    from vllm.utils import get_open_port
except ImportError:

    def get_open_port():
        """Return a TCP port number the OS reported as free.

        Binds an IPv4 stream socket to port 0 on all interfaces and reads back
        the port the OS assigned. The socket is closed before returning, so the
        port is not reserved for the caller.
        """
        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with sock:
            sock.bind(("", 0))
            _host, port = sock.getsockname()
        return port
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class _vLLMWorker(Worker):
    """Private vLLM worker for Ray.

    Weights are received over a stateless process group in which rank 0 is
    the sender (``update_weight_broadcast`` broadcasts from ``src=0``). Each
    worker joins at ``local_rank + rank_offset``, so with ``rank_offset=1``
    the tensor-parallel ranks occupy ``[1, tp_size]``.
    """

    def __init__(self, *args, **kwargs):
        if not _has_vllm:
            raise ImportError(
                "vllm is not installed. Please install it with `pip install vllm`."
            )

        worker_name = type(self).__name__
        torchrl_logger.info(f"=> in {worker_name}.__init__")
        torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}")
        torchrl_logger.info(f"device count {torch.cuda.device_count()}")
        super().__init__(*args, **kwargs)

    def init_weight_update_group(
        self, master_address, master_port, rank_offset, world_size
    ):
        """Join the weight-update process group and store it on ``self``.

        Args:
            master_address: address of the group's rendezvous host.
            master_port: port of the group's rendezvous host.
            rank_offset: value added to this worker's rank in vLLM's world
                group to obtain its rank in the weight-update group.
            world_size: total number of members of the weight-update group.
        """
        from vllm.distributed.parallel_state import get_world_group

        worker_name = type(self).__name__
        torchrl_logger.info(f"=> in {worker_name}.init_weight_update_group")

        local_rank = get_world_group().rank
        torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")

        # Shift past the sender's rank 0; e.g. rank_offset=1 with tp_size=2
        # yields ranks 1 and 2. The local name ``rank`` is kept because the
        # ``{rank=}`` f-string below prints the expression text.
        rank = local_rank + rank_offset
        torchrl_logger.info(
            f"Initializing {worker_name} weight update group with "
            f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}"
        )

        self.model_update_group = stateless_init_process_group(
            master_address, master_port, rank, world_size, self.device
        )
        torchrl_logger.info(f"{worker_name}.init_weight_update_group success")

    def update_weight_broadcast(self, name, dtype, shape):
        """Receive one parameter broadcast from rank 0 and load it into the model."""
        # Receive buffer, filled in place by the broadcast.
        buffer = torch.empty(shape, dtype=dtype, device="cuda")
        stream = torch.cuda.current_stream()
        self.model_update_group.broadcast(buffer, src=0, stream=stream)

        self.model_runner.model.load_weights(weights=[(name, buffer)])
        del buffer

    def update_weight(self, name, weight):
        """Load a single ``(name, weight)`` pair into the model."""
        weights = [(name, weight)]
        self.model_runner.model.load_weights(weights=weights)
        del weight

    def check_weights_changed(self):
        """Return True when every model parameter is all-close to zero.

        Also returns True when the model has no parameters.
        """
        # TODO: This is a test and should be treated as such
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for param in self.model_runner.model.parameters()
        )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class _LLMOnDevice(LLM):
    """Private wrapper around `vllm.LLM` to control its placement devices.

    ``__init__`` only prepares environment variables and stores the
    constructor arguments. The actual ``vllm.LLM`` construction is deferred
    to :meth:`initialize`.
    """

    def __init__(self, *args, bundle_indices: list | None = None, **kwargs):
        if not _has_vllm:
            raise ImportError(
                "vllm is not installed. Please install it with `pip install vllm`."
            )

        # Stop Ray from manipulating CUDA_VISIBLE_DEVICES at the top-level
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)

        if bundle_indices is not None:
            # Configure GPU utilization for Ray workers: a fractional share of
            # 0.4 GPU per worker allows multiple workers per GPU.
            os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
            indices_csv = ",".join(str(idx) for idx in bundle_indices)
            os.environ["VLLM_RAY_BUNDLE_INDICES"] = indices_csv
            torchrl_logger.info(
                f"Initializing LLM with bundle_indices={bundle_indices}"
            )

        # Kept for the deferred constructor call in initialize().
        self.args, self.kwargs = args, kwargs

    def initialize(self):
        """Run the deferred ``vllm.LLM`` constructor; returns True on success."""
        # Let vLLM handle device placement
        super().__init__(*self.args, **self.kwargs)
        return True
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class RayLLMWorker(RLvLLMEngine):
    """A wrapper for Ray-based vLLM workers that implements the RLvLLMEngine interface.

    This class wraps a Ray actor handle for a vLLM worker and provides the
    standardized interface for weight updates and configuration access.

    Args:
        ray_actor: handle to the Ray actor; its ``collective_rpc`` and
            ``generate`` methods are invoked with ``.remote(...)``.
        tensor_parallel_size (int): tensor-parallel size of the engine. The
            weight-sync group has ``tensor_parallel_size + 1`` members.
        model_name (str): name of the served model.
    """

    def __init__(self, ray_actor, tensor_parallel_size: int, model_name: str):
        self.ray_actor = ray_actor
        self._tensor_parallel_size = tensor_parallel_size
        self._model_name = model_name
        # Resolved lazily by get_master_address() / get_master_port().
        self._master_address = None
        self._master_port = None

    @staticmethod
    def _import_ray(purpose: str):
        """Import and return the ``ray`` module.

        Only the import itself is guarded. ImportErrors raised later (e.g. one
        that may surface from a remote task through ``ray.get``) are therefore
        not misreported as Ray being unavailable.

        Raises:
            ImportError: if Ray cannot be imported; the original error is chained.
        """
        try:
            import ray
        except ImportError as err:
            raise ImportError(f"Ray not available for {purpose}") from err
        return ray

    def get_tp_size(self) -> int:
        """Get the tensor parallel size."""
        return self._tensor_parallel_size

    def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Get model parameter metadata.

        For Ray workers, this requires loading the model to inspect parameters.
        Currently returns empty dict - should be implemented when needed.
        """
        # TODO: Implement metadata extraction from Ray worker
        torchrl_logger.warning(
            "RayLLMWorker.get_model_metadata() not implemented - returning empty dict"
        )
        return {}

    def get_master_address(self) -> str:
        """Get the master address for weight synchronization (``"localhost"``)."""
        if self._master_address is None:
            self._master_address = "localhost"
        return self._master_address

    def get_master_port(self) -> int:
        """Get the master port for weight synchronization, picked once and cached."""
        if self._master_port is None:
            self._master_port = get_open_port() if callable(get_open_port) else 29500
        return self._master_port

    def init_weight_update_group(self) -> None:
        """Initialize the weight update communication group.

        Calls ``init_weight_update_group`` on the actor's workers through
        ``collective_rpc`` and blocks until it completes. The group has
        ``tensor_parallel_size + 1`` members, and worker ranks are offset by 1.

        Raises:
            ImportError: if Ray is not installed.
        """
        weight_sync_world_size = self._tensor_parallel_size + 1
        ray = self._import_ray("weight update group initialization")

        # Initialize weight update group on the Ray worker
        ray.get(
            self.ray_actor.collective_rpc.remote(
                "init_weight_update_group",
                args=(
                    self.get_master_address(),
                    self.get_master_port(),
                    1,  # rank offset: rank 0 is left for the weight sender
                    weight_sync_world_size,
                ),
            )
        )
        torchrl_logger.info("Ray worker weight update group initialized")

    def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
        """Update model weights via the Ray worker.

        Args:
            weights: Iterator yielding (parameter_name, tensor) tuples. Each
                tensor is moved to ``"cuda"`` before being sent.

        Raises:
            ImportError: if Ray is not installed.
        """
        ray = self._import_ray("weight updates")

        # Convert iterator to list for Ray serialization
        weights_list = list(weights)
        if not weights_list:
            torchrl_logger.warning("No weights provided for update")
            return

        torchrl_logger.info(f"Updating {len(weights_list)} parameters on Ray worker")

        # Send weights to the Ray worker: one remote call per parameter,
        # then wait for all of them.
        remotes = [
            self.ray_actor.collective_rpc.remote(
                "update_weight", args=(name, weight.to("cuda"))
            )
            for name, weight in weights_list
        ]
        ray.get(remotes)
        torchrl_logger.info("Ray worker weight update completed")

    # Delegate generation methods to the Ray actor
    def generate(self, *args, **kwargs):
        """Generate text using the Ray worker (blocking on the remote result).

        Raises:
            ImportError: if Ray is not installed.
        """
        ray = self._import_ray("generation")
        return ray.get(self.ray_actor.generate.remote(*args, **kwargs))
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class LocalLLMWrapper(RLvLLMEngine):
    """Adapter exposing an in-process ``vllm.LLM`` through the RLvLLMEngine API.

    Weight-sync methods are no-ops that only log, since the model lives in
    the same process. Generation is forwarded to the wrapped instance.

    Args:
        llm_instance: the local LLM object; its ``generate`` method is called.
        tensor_parallel_size (int): value reported by :meth:`get_tp_size`.
        model_name (str): name of the served model.
    """

    def __init__(self, llm_instance, tensor_parallel_size: int, model_name: str):
        self.llm_instance = llm_instance
        self._tensor_parallel_size = tensor_parallel_size
        self._model_name = model_name
        # Filled on first access by the getters below.
        self._master_address = None
        self._master_port = None

    def get_tp_size(self) -> int:
        """Return the configured tensor parallel size."""
        return self._tensor_parallel_size

    def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Return parameter metadata; currently always an empty dict (with a warning)."""
        # TODO: Implement metadata extraction from local LLM
        message = (
            "LocalLLMWrapper.get_model_metadata() not implemented - returning empty dict"
        )
        torchrl_logger.warning(message)
        return {}

    def get_master_address(self) -> str:
        """Return the weight-sync master address, defaulting to ``"localhost"``."""
        if self._master_address is not None:
            return self._master_address
        self._master_address = "localhost"
        return self._master_address

    def get_master_port(self) -> int:
        """Return the weight-sync master port, chosen on first call and cached."""
        if self._master_port is not None:
            return self._master_port
        if callable(get_open_port):
            self._master_port = get_open_port()
        else:
            self._master_port = 29500
        return self._master_port

    def init_weight_update_group(self) -> None:
        """No-op for a local model; only logs."""
        torchrl_logger.info("Local LLM weight update group initialized (no-op)")

    def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
        """Consume ``weights`` and log how many were given; nothing is loaded.

        For local LLM instances, weight updates are not applicable since
        the model is in the same process space.
        """
        n_params = sum(1 for _ in weights)
        torchrl_logger.info(
            f"Local LLM weight update (no-op) for {n_params} parameters"
        )

    # Delegate generation methods to the local LLM
    def generate(self, *args, **kwargs):
        """Forward all arguments to the wrapped LLM's ``generate``."""
        return self.llm_instance.generate(*args, **kwargs)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def make_vllm_worker(
    *,
    model_name: str,
    devices: list[torch.device | int] | None = None,
    num_devices: int | None = None,
    make_ray_worker: bool = True,
    enforce_eager: bool = False,
    enable_fp32_output: bool = False,
    **kwargs,
) -> RayLLMWorker | LocalLLMWrapper:
    """Creates a vLLM inference engine with tensor parallelism support.

    Args:
        model_name (str): The model name to pass to vLLM.LLM.
        devices (list[torch.device | int], optional): List of devices to use. Exclusive with num_devices.
            Entries may be CUDA device indices or ``torch.device`` objects (or strings accepted by
            ``torch.device``); they are normalized to integer indices. When neither ``devices`` nor
            ``num_devices`` is given, ``[0]`` is used.
        num_devices (int, optional): Number of devices to use. Exclusive with devices.
        make_ray_worker (bool, optional): Whether to create a Ray actor. Defaults to True.
        enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to `False`.
            Only forwarded to the Ray worker; the local engine (``make_ray_worker=False``) is
            always created with ``enforce_eager=True``.
        enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
            This can help with numerical stability for certain models. Requires model-specific support in
            torchrl.modules.llm.backends._models. Note that this sets the process-wide
            ``VLLM_ENABLE_FP32_OUTPUT`` environment variable.
        **kwargs: Additional arguments passed to vLLM.LLM.__init__.

    Returns:
        RayLLMWorker | LocalLLMWrapper: Either a Ray worker wrapper or a local LLM wrapper, both implementing RLvLLMEngine.

    Raises:
        ImportError: If vLLM is not installed.
        ValueError: If both ``devices`` and ``num_devices`` are given, if ``devices`` is empty,
            if ``num_devices`` is smaller than 1, or if a device is not a valid CUDA device index.
        RuntimeError: If ``make_ray_worker=True`` and Ray has not been initialized.

    Example:
        >>> # Create a 2-GPU tensor parallel worker with Ray
        >>> worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2)
        >>> # Create a local LLM instance on GPU 1
        >>> llm = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", devices=[1], make_ray_worker=False)
        >>> # Create with FP32 output enabled
        >>> worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2, enable_fp32_output=True)
    """
    if not _has_vllm:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    # Set FP32 output environment variable if requested
    if enable_fp32_output:
        os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
        torchrl_logger.info(
            "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
            "This will use FP32 for the final output layer if the model supports it."
        )

    # Handle device specification. After this section num_devices is always
    # an int >= 1, and devices is either None (num_devices given) or a list
    # of int indices of length num_devices.
    if num_devices is not None and devices is not None:
        raise ValueError("Cannot specify both num_devices and devices")
    if num_devices is not None:
        if num_devices < 1:
            raise ValueError(f"num_devices must be >= 1, got {num_devices}")
    else:
        if devices is None:
            devices = [0]  # Default to first GPU
        elif not devices:
            raise ValueError("devices must be a non-empty list")
        # Normalize every entry to an integer index. This must also run for
        # single-device lists: otherwise a lone torch.device would fail the
        # validation below and num_devices would stay None.
        devices = [
            device if isinstance(device, int) else torch.device(device).index
            for device in devices
        ]
        num_devices = len(devices)

    # Validate devices
    if devices is not None:
        visible_device_count = torch.cuda.device_count()
        for d in devices:
            if not isinstance(d, int) or d < 0 or d >= visible_device_count:
                raise ValueError(f"Invalid device index: {d}")

    if make_ray_worker:
        import ray

        if not ray.is_initialized():
            raise RuntimeError("Ray is not initialized")

        torchrl_logger.info(
            f"Creating vLLM Ray worker with tensor_parallel_size={num_devices}"
        )

        # Configure Ray remote class with minimal resources
        # Let vLLM handle GPU allocation through environment variables
        worker_cls = ray.remote(
            num_cpus=4,  # Minimal CPU request
            num_gpus=0,  # Let vLLM handle GPU allocation
        )(_LLMOnDevice)

        # Create worker with tensor parallelism config
        worker = worker_cls.remote(
            model=model_name,
            bundle_indices=devices,  # Pass device indices to _LLMOnDevice
            tensor_parallel_size=num_devices,
            distributed_executor_backend="ray",
            enforce_eager=enforce_eager,
            worker_cls="torchrl.modules.llm.backends.vllm.vllm_sync._vLLMWorker",
            **kwargs,
        )
        ray.get(worker.initialize.remote())

        # Wrap the Ray actor in RayLLMWorker to provide RLvLLMEngine interface
        return RayLLMWorker(worker, num_devices, model_name)

    # Local non-Ray mode - use LLM directly
    with _cuda_visible_devices(devices) if devices is not None else nullcontext():
        torchrl_logger.info(
            f"Creating local vLLM LLM with tensor_parallel_size={num_devices}, devices={devices}"
        )
        # NOTE(review): local mode always forces eager execution and ignores
        # the ``enforce_eager`` argument; kept as-is to preserve behavior.
        llm_instance = LLM(
            model=model_name,
            tensor_parallel_size=num_devices,
            enforce_eager=True,
            **kwargs,
        )

    # Wrap the local LLM to provide RLvLLMEngine interface
    return LocalLLMWrapper(llm_instance, num_devices, model_name)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Shared utilities for vLLM backends."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from torchrl._utils import logger as torchrl_logger
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
|
16
|
+
from vllm.distributed.utils import StatelessProcessGroup
|
|
17
|
+
|
|
18
|
+
_has_vllm = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
PyNcclCommunicator = None
|
|
21
|
+
StatelessProcessGroup = None
|
|
22
|
+
_has_vllm = False
|
|
23
|
+
|
|
24
|
+
# Prefer vLLM's own port picker; older/newer vLLM releases may not expose it
# under ``vllm.utils``, in which case a standard-library version is used.
try:
    from vllm.utils import get_open_port
except ImportError:

    def get_open_port():
        """Return a TCP port number that was free when this function ran.

        The OS chooses the port (bind to port 0 on all interfaces); the probe
        socket is closed before returning, so the port is not reserved.
        """
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            _host, port = probe.getsockname()
        return port
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def stateless_init_process_group(
    master_address: str | None,
    master_port: str | int | None,
    rank,
    world_size,
    device=None,
):
    """Initializes a stateless process group for distributed communication.

    Creates a `StatelessProcessGroup` instance without relying on the global
    process group in `torch.distributed`. This approach is recommended for
    initializing data-plane communication (NCCL) between external processes
    (e.g., training processes) and vLLM workers.

    Args:
        master_address (str | None): The address of the master node. Defaults to "localhost" if not specified.
        master_port (str | int | None): The port used by the master node. If not specified, an open
            port is picked on the local host. That port is only known to the calling process
            (it is logged, not returned), so other ranks must be given the same value explicitly.
        rank (int): The rank of the current process.
        world_size (int): The total number of processes in the distributed group.
        device (torch.device, optional): The device to use for communication.
            Defaults to ``torch.device("cuda:0")``.

    Returns:
        PyNcclCommunicator: A PyNcclCommunicator instance initialized with the created StatelessProcessGroup.

    Raises:
        ImportError: If vLLM (and thus ``StatelessProcessGroup`` / ``PyNcclCommunicator``) is unavailable.
    """
    if not _has_vllm or StatelessProcessGroup is None or PyNcclCommunicator is None:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    if master_address is None:
        master_address = "localhost"  # get_ip()
    if master_port is None:
        # get_open_port is always defined in this module (vLLM's or the
        # stdlib fallback), so no further fallback is needed.
        master_port = get_open_port()

    torchrl_logger.info(
        f"Initializing stateless process group: rank={rank}, world_size={world_size}, master_address={master_address}, master_port={master_port}"
    )
    pg = StatelessProcessGroup.create(
        host=master_address, port=int(master_port), rank=rank, world_size=world_size
    )
    if device is None:
        device = torch.device("cuda:0")
    return PyNcclCommunicator(pg, device=device)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def stateless_init_process_group_async(
    master_address: str | None,
    master_port: str | int | None,
    rank: int,
    world_size: int,
    device=None,
):
    """Initializes a stateless process group for distributed communication (async version).

    Creates a `StatelessProcessGroup` instance without relying on the global
    process group in `torch.distributed`. This approach is recommended for
    initializing data-plane communication (NCCL) between external processes
    (e.g., training processes) and vLLM workers.

    Note:
        The body contains no ``await``: group creation and communicator
        construction run synchronously, so the event loop is not yielded to
        while they execute.

    Args:
        master_address (str | None): The address of the master node. Defaults to "localhost" if not specified.
        master_port (str | int | None): The port used by the master node. If not specified, an open
            port is picked on the local host. That port is only known to the calling process
            (it is logged, not returned), so other ranks must be given the same value explicitly.
        rank (int): The rank of the current process.
        world_size (int): The total number of processes in the distributed group.
        device (torch.device, optional): The device to use for communication.
            Defaults to ``torch.device("cuda:0")``.

    Returns:
        PyNcclCommunicator: A PyNcclCommunicator instance initialized with the created StatelessProcessGroup.

    Raises:
        ImportError: If vLLM (and thus ``StatelessProcessGroup`` / ``PyNcclCommunicator``) is unavailable.
    """
    if not _has_vllm or StatelessProcessGroup is None or PyNcclCommunicator is None:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    if master_address is None:
        master_address = "localhost"
    if master_port is None:
        # get_open_port is always defined in this module (vLLM's or the
        # stdlib fallback), so no further fallback is needed.
        master_port = get_open_port()

    torchrl_logger.info(
        f"Initializing stateless process group: rank={rank}, world_size={world_size}, master_address={master_address}, master_port={master_port}"
    )
    pg = StatelessProcessGroup.create(
        host=master_address, port=int(master_port), rank=rank, world_size=world_size
    )
    if device is None:
        device = torch.device("cuda:0")
    return PyNcclCommunicator(pg, device=device)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""LLM policy wrappers.
|
|
6
|
+
|
|
7
|
+
This subpackage includes optional wrappers that may rely on native extensions
(e.g. vLLM). So that importing this subpackage does not require those optional
dependencies to be installed, they are not imported at module import time.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from .common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
|
|
15
|
+
from .transformers_wrapper import RemoteTransformersWrapper, TransformersWrapper
|
|
16
|
+
from .vllm_wrapper import vLLMWrapper
|
|
17
|
+
|
|
18
|
+
# Public API of this subpackage: the LLM wrapper classes, plus the base class
# and helper types re-exported from ``.common``.
__all__ = [
    "TransformersWrapper",
    "RemoteTransformersWrapper",
    "vLLMWrapper",
    "LLMWrapperBase",
    "Text",
    "LogProbs",
    "Masks",
    "Tokens",
    "ChatHistory",
]
|