torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,2075 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Async vLLM engine implementation for efficient batching and inference.
|
|
7
|
+
|
|
8
|
+
This module provides an async vLLM engine that leverages native vLLM batching
|
|
9
|
+
for better performance and memory efficiency compared to the explicit batching
|
|
10
|
+
approach used in the legacy vLLM backend.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import os
|
|
17
|
+
import random
|
|
18
|
+
import time
|
|
19
|
+
import uuid
|
|
20
|
+
from collections.abc import Iterator, Sequence
|
|
21
|
+
from concurrent.futures import ThreadPoolExecutor, wait
|
|
22
|
+
from typing import Any, Literal, TYPE_CHECKING
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
from torchrl._utils import logger as torchrl_logger
|
|
27
|
+
|
|
28
|
+
# Import RLvLLMEngine and shared utilities
|
|
29
|
+
from .base import RLvLLMEngine
|
|
30
|
+
from .vllm_utils import stateless_init_process_group
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
_has_vllm = True
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from vllm.engine.async_llm_engine import AsyncEngineArgs
|
|
37
|
+
from vllm.engine.request import RequestOutput
|
|
38
|
+
from vllm.engine.sampling_params import SamplingParams
|
|
39
|
+
|
|
40
|
+
# Default timeout (in seconds) for vLLM operations. It can be overridden with
# the TORCHRL_VLLM_TIMEOUT_SECONDS environment variable. os.getenv returns a
# string whenever the variable is set, so the value is coerced to float to
# keep the constant numeric in both the default and the overridden case.
TIMEOUT_SECONDS = float(os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", "300"))
|
|
41
|
+
|
|
42
|
+
# Optional dependency guard: vllm is only required once an engine is actually
# built, so a missing install is recorded here instead of failing at import.
try:
    import vllm
except ImportError:
    vllm = None
    _has_vllm = False
else:
    _has_vllm = True
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_ray():
    """Lazily import and return the Ray module.

    Ray is imported inside the function so that importing this module has no
    Ray-related side effects.

    Returns:
        ModuleType: The imported Ray module.

    Raises:
        ImportError: If importing Ray fails for any reason; the original
            exception is chained as the cause.
    """
    try:
        import ray  # type: ignore
    except Exception as err:  # pragma: no cover - surfaced to callers
        raise ImportError(
            "ray is not installed. Please install it with `pip install ray`."
        ) from err
    else:
        return ray
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class _AsyncvLLMWorker:
    """Async vLLM worker extension for Ray with weight update capabilities.

    The methods read ``self.device`` and ``self.model_runner`` and lazily
    create ``self.model_update_group``; none of these are defined in this
    class, so they are presumably provided by the vLLM worker this extension
    is attached to (to be confirmed against the engine configuration).
    """

    def init_weight_update_group(
        self,
        master_address: str,
        master_port: str,
        rank_offset: int,
        world_size: int,
    ):
        """Initialize weight update group for this worker (non-blocking).

        This method starts NCCL initialization in a background thread and returns immediately,
        allowing the RPC to complete. The NCCL collective will complete when the trainer joins.

        If ``model_update_group`` is already set, the call is a no-op.

        Args:
            master_address (str): The master address for distributed training.
            master_port (str): The master port for distributed training.
            rank_offset (int): Rank offset for this worker in the global weight update group.
            world_size (int): Total number of processes in the weight update group.
        """
        import threading

        from vllm.distributed.parallel_state import get_world_group

        torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
        if getattr(self, "model_update_group", None) is not None:
            torchrl_logger.info("Model update group already initialized")
            return

        # Rank of this worker as reported by vLLM's world group.
        tp_group = get_world_group()
        local_rank = tp_group.rank
        torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")

        # Calculate the global rank for weight update group
        rank = local_rank + rank_offset
        torchrl_logger.info(
            f"Starting {type(self).__name__} weight update group init (non-blocking) with "
            f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}"
        )

        # Start NCCL init in a background thread so this RPC can return immediately
        def _init_nccl_background():
            try:
                from .vllm_utils import stateless_init_process_group

                torchrl_logger.info(
                    f"Worker rank {rank}: Starting NCCL init (will block until collective completes)..."
                )
                self.model_update_group = stateless_init_process_group(
                    master_address, master_port, rank, world_size, self.device
                )
                torchrl_logger.info(f"Worker rank {rank}: NCCL init complete!")
            except Exception as e:
                torchrl_logger.error(f"Worker rank {rank}: NCCL init failed: {e}")
                raise

        thread = threading.Thread(target=_init_nccl_background, daemon=False)
        thread.start()

        # Store thread reference for potential cleanup
        self._nccl_init_thread = thread

        torchrl_logger.info(
            f"{type(self).__name__}.init_weight_update_group dispatched (non-blocking)"
        )

    def update_weight(self, name: str, dtype_name: str, shape: tuple[int, ...]):
        """Update weight via broadcast from master (rank 0) - periodic-mono pattern.

        Args:
            name (str): Parameter name.
            dtype_name (str): Parameter dtype name (e.g., 'bfloat16').
            shape (tuple[int, ...]): Parameter shape.

        Raises:
            RuntimeError: If the weight update group has not been initialized
                (including while the background initialization is still running).
        """
        # The attribute is only created by the background init thread, so use
        # getattr to report a clear RuntimeError instead of an AttributeError.
        if getattr(self, "model_update_group", None) is None:
            raise RuntimeError("Weight update group not initialized")

        # Convert dtype name to dtype (like periodic-mono)
        dtype = getattr(torch, dtype_name)

        # Workers receive broadcast from master (rank 0)
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(
            weight, src=0, stream=torch.cuda.current_stream()
        )
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight

    def check_nccl_group_ready(self):
        """Check if NCCL group is ready for communication.

        Returns:
            bool: ``True`` once ``model_update_group`` has been set, ``False``
            otherwise (including before initialization was ever requested).
        """
        # getattr: the attribute does not exist until the init thread finishes.
        ready = getattr(self, "model_update_group", None) is not None
        torchrl_logger.info(f"Worker NCCL group ready: {ready}")
        return ready

    def load_weights_from_storage(self, storage_path: str, num_threads: int = 1):
        """Load weights from shared storage (double-buffer approach).

        This method reads weights from a memory-mapped TensorDict directory
        and loads them into the model. Used for file-based weight synchronization
        as an alternative to NCCL collectives.

        Args:
            storage_path: Path to the directory containing memory-mapped weights
            num_threads: Number of threads for reading (default: 1)

        Raises:
            Exception: Any exception raised while loading a weight is re-raised
                to the caller instead of being silently dropped.
        """
        from tensordict import TensorDict

        torchrl_logger.info(f"Worker loading weights from {storage_path}")

        # Read weights from shared storage
        weights = TensorDict.load_memmap(storage_path)
        weights = weights.flatten_keys(".")

        # Convert to list of (name, tensor) tuples
        weights_list = list(weights.items())

        torchrl_logger.info(f"Worker loading {len(weights_list)} weights into model")

        load_weights = self.model_runner.model.load_weights
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            # load_weights is given a collection of (name, tensor) pairs, as in
            # update_weight, so each entry is wrapped in a one-element list.
            futures = [
                executor.submit(load_weights, [(name, tensor)])
                for name, tensor in weights_list
            ]
            # result() blocks until completion and re-raises worker errors, so
            # a failed load is never reported as a success below.
            for future in futures:
                future.result()

        torchrl_logger.info(
            f"Worker successfully loaded {len(weights_list)} weights from storage"
        )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class _AsyncLLMEngine:
|
|
203
|
+
"""Extended AsyncLLMEngine with TorchRL-specific features.
|
|
204
|
+
|
|
205
|
+
This class wraps vLLM's AsyncLLMEngine and adds functionality needed
|
|
206
|
+
for TorchRL integration, including weight updates and batch management.
|
|
207
|
+
|
|
208
|
+
This is a private class and should not be used directly. Use the ray remote actor class :class:`AsyncLLMEngineActor` instead.
|
|
209
|
+
|
|
210
|
+
Keyword Args:
|
|
211
|
+
engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
|
|
212
|
+
bundle_indices (list[int], optional): Bundle indices for the engine.
|
|
213
|
+
enable_prefix_caching (bool, optional): Whether to enable prefix caching.
|
|
214
|
+
|
|
215
|
+
.. warning::
|
|
216
|
+
enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
|
|
217
|
+
Set it to True if prompt log probs are not needed.
|
|
218
|
+
See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
|
|
219
|
+
"""
|
|
220
|
+
|
|
221
|
+
def __init__(
|
|
222
|
+
self,
|
|
223
|
+
*,
|
|
224
|
+
engine_args: AsyncEngineArgs,
|
|
225
|
+
bundle_indices: list[int] | None = None,
|
|
226
|
+
enable_prefix_caching: bool = False,
|
|
227
|
+
):
|
|
228
|
+
if not _has_vllm:
|
|
229
|
+
raise ImportError(
|
|
230
|
+
"vllm is not installed. Please install it with `pip install vllm`."
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
from vllm import AsyncLLMEngine
|
|
234
|
+
|
|
235
|
+
if bundle_indices is not None:
|
|
236
|
+
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
|
|
237
|
+
|
|
238
|
+
engine_args.enable_prefix_caching = enable_prefix_caching
|
|
239
|
+
|
|
240
|
+
# Fix for vLLM issue #19123: Set RAY_ADDRESS so vLLM subprocesses connect
|
|
241
|
+
# to the same Ray cluster instead of starting a new one (causes KeyError: 'bundles')
|
|
242
|
+
try:
|
|
243
|
+
import ray
|
|
244
|
+
|
|
245
|
+
if ray.is_initialized():
|
|
246
|
+
# Get the current Ray address and set it in the environment
|
|
247
|
+
# so vLLM's subprocess connects to the same cluster
|
|
248
|
+
ray_address = ray.get_runtime_context().gcs_address
|
|
249
|
+
if ray_address and "RAY_ADDRESS" not in os.environ:
|
|
250
|
+
os.environ["RAY_ADDRESS"] = ray_address
|
|
251
|
+
torchrl_logger.debug(
|
|
252
|
+
f"Set RAY_ADDRESS={ray_address} for vLLM subprocess"
|
|
253
|
+
)
|
|
254
|
+
except Exception:
|
|
255
|
+
pass # Ray not available or not initialized, let vLLM handle it
|
|
256
|
+
|
|
257
|
+
# Create the engine directly - this is the source of the blocking ray.get issue
|
|
258
|
+
# but we need to handle it differently for multiple replicas
|
|
259
|
+
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
|
260
|
+
self.bundle_indices = bundle_indices
|
|
261
|
+
|
|
262
|
+
def ready(self) -> bool:
|
|
263
|
+
"""Check if engine is ready for inference."""
|
|
264
|
+
return True
|
|
265
|
+
|
|
266
|
+
async def generate(
    self,
    prompts: Any = None,
    sampling_params: SamplingParams | None = None,
    *,
    prompt_token_ids: list[int] | list[list[int]] | None = None,
    use_tqdm: bool = True,
    lora_request: Any = None,
    prompt_adapter_request: Any = None,
    guided_options_request: Any = None,
    timeout_seconds: float | None = None,
) -> RequestOutput | list[RequestOutput]:
    """Generate text with the same interface as vLLM.LLM.generate.

    This method mirrors the interface of vLLM.LLM.generate to provide seamless
    compatibility between sync and async engines. A single prompt yields a
    single ``RequestOutput``; a list/tuple of prompts is processed concurrently
    and yields a list of outputs in the same order.

    Args:
        prompts: String, TokensPrompt, or list of these. Input prompts for generation.
        sampling_params: SamplingParams object for controlling generation behavior.
            A default ``SamplingParams()`` is used when omitted.
        prompt_token_ids: Alternative to prompts - token IDs for generation. Either a
            flat list (one prompt) or a list of lists (one prompt per inner list).
        use_tqdm: Accepted for interface compatibility; not used by this method.
        lora_request: LoRA request for adapter-based generation.
        prompt_adapter_request: Prompt adapter request, forwarded to the engine when set.
        guided_options_request: Guided decoding options.
        timeout_seconds: Timeout for generation in seconds. ``None`` or a
            non-positive value disables the timeout.

    Returns:
        RequestOutput or list of RequestOutput: Generated outputs from vLLM.

    Raises:
        ImportError: If vLLM is not installed.
        ValueError: If both or neither of ``prompts`` / ``prompt_token_ids`` are
            given, or if ``prompt_token_ids`` is empty.
        TimeoutError: If generation does not finish within ``timeout_seconds``.
        RuntimeError: If the engine stream ends without a finished output.
    """
    if not _has_vllm:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    import inspect

    from vllm import SamplingParams, TokensPrompt

    # Track whether input was originally a single prompt
    single_prompt_input = False

    if prompt_token_ids is not None:
        if prompts is not None:
            raise ValueError("Cannot specify both prompts and prompt_token_ids")
        if not prompt_token_ids:
            raise ValueError("prompt_token_ids cannot be empty")

        if isinstance(prompt_token_ids[0], list):
            # List of token ID lists -> one TokensPrompt per inner list
            prompts = [
                TokensPrompt(prompt_token_ids=tokens) for tokens in prompt_token_ids
            ]
        else:
            # Single flat token ID list
            prompts = TokensPrompt(prompt_token_ids=list(prompt_token_ids))
            single_prompt_input = True
    elif prompts is None:
        raise ValueError("Must specify either prompts or prompt_token_ids")
    elif not isinstance(prompts, (list, tuple)):
        single_prompt_input = True

    if sampling_params is None:
        sampling_params = SamplingParams()

    # Optional engine arguments are only forwarded when explicitly provided.
    optional_kwargs = {
        key: value
        for key, value in (
            ("lora_request", lora_request),
            ("prompt_adapter_request", prompt_adapter_request),
            ("guided_options_request", guided_options_request),
        )
        if value is not None
    }

    # Request IDs that were submitted but have not produced a final output yet;
    # used to abort leftovers when the call times out.
    in_flight: set[str] = set()

    async def _gen_one(prompt) -> RequestOutput:
        request_id = str(uuid.uuid4())
        in_flight.add(request_id)
        final = None
        async for output in self.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            **optional_kwargs,
        ):
            if output.finished:
                final = output
        in_flight.discard(request_id)
        if final is None:
            # Explicit error instead of ``assert`` so the check survives ``-O``.
            raise RuntimeError(
                f"vLLM request {request_id} ended without a finished output"
            )
        return final

    async def _run_generation():
        if single_prompt_input:
            return await _gen_one(prompts)

        # List of prompts: run concurrently
        tasks = [asyncio.create_task(_gen_one(p)) for p in prompts]
        return await asyncio.gather(*tasks)

    async def _abort_in_flight() -> None:
        """Best-effort abort of requests that were still running at timeout."""
        abort_fn = getattr(self.engine, "abort", None)
        if not callable(abort_fn):
            return
        for request_id in list(in_flight):
            try:
                outcome = abort_fn(request_id)
                # ``abort`` may be sync or async depending on the engine.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as abort_err:
                # Cleanup must never mask the timeout itself.
                torchrl_logger.warning(
                    f"Failed to abort vLLM request {request_id}: {abort_err}"
                )

    try:
        if timeout_seconds is not None and timeout_seconds > 0:
            return await asyncio.wait_for(_run_generation(), timeout=timeout_seconds)
        return await _run_generation()
    # asyncio.TimeoutError is only an alias of the builtin from Python 3.11 on;
    # catch both so the timeout is handled on older interpreters too.
    except (asyncio.TimeoutError, TimeoutError) as err:
        await _abort_in_flight()
        raise TimeoutError(
            f"vLLM generation timed out after {timeout_seconds} seconds"
        ) from err
|
|
390
|
+
|
|
391
|
+
async def get_tokenizer(self):
    """Return the tokenizer obtained by awaiting ``self.engine.get_tokenizer()``."""
    tokenizer = await self.engine.get_tokenizer()
    return tokenizer
|
|
394
|
+
|
|
395
|
+
async def collective_rpc_v1(
    self,
    method: str,
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict | None = None,
):
    """Perform a collective RPC call to the given method (vLLM V1).

    When ``vllm.envs.VLLM_USE_V1`` is truthy the call is awaited on
    ``self.engine``; otherwise it is issued (without awaiting) on the nested
    ``self.engine.engine``.

    Args:
        method (str): Method name to call.
        timeout (float | None): Timeout for the RPC call.
        args (tuple): Arguments to pass to the method.
        kwargs (dict | None): Keyword arguments to pass to the method.

    Returns:
        Whatever the underlying ``collective_rpc`` call returns.
    """
    from vllm import envs

    v1_active = bool(envs and envs.VLLM_USE_V1)
    if not v1_active:
        return self.engine.engine.collective_rpc(method, timeout, args, kwargs)
    return await self.engine.collective_rpc(method, timeout, args, kwargs)
|
|
416
|
+
|
|
417
|
+
def collective_rpc_v0(
    self,
    method: str,
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict | None = None,
):
    """Perform a collective RPC call to the given method (vLLM V0).

    The call is forwarded positionally to ``self.engine.engine.collective_rpc``.

    Args:
        method (str): Method name to call.
        timeout (float | None): Timeout for the RPC call.
        args (tuple): Arguments to pass to the method.
        kwargs (dict | None): Keyword arguments to pass to the method.

    Returns:
        Whatever the underlying ``collective_rpc`` call returns.
    """
    inner_engine = self.engine.engine
    outcome = inner_engine.collective_rpc(method, timeout, args, kwargs)
    return outcome
|
|
433
|
+
|
|
434
|
+
def get_num_unfinished_requests(self) -> int:
    """Get the number of unfinished requests in the engine.

    The count is read from ``self.engine`` when it exposes
    ``get_num_unfinished_requests``, otherwise from the nested
    ``self.engine.engine``. If neither provides it, or if the lookup raises,
    a warning is logged and ``0`` is returned.

    Returns:
        int: Number of unfinished requests.
    """
    try:
        outer = self.engine
        # Preferred: the method lives directly on the engine.
        if hasattr(outer, "get_num_unfinished_requests"):
            return outer.get_num_unfinished_requests()

        # Otherwise look one level down (engine.engine).
        if hasattr(outer, "engine") and hasattr(
            outer.engine, "get_num_unfinished_requests"
        ):
            return outer.engine.get_num_unfinished_requests()

        torchrl_logger.warning(
            "get_num_unfinished_requests not available, returning 0"
        )
        return 0
    except Exception as e:
        torchrl_logger.warning(f"Error getting unfinished requests count: {e}")
        return 0
|
|
458
|
+
|
|
459
|
+
def get_cache_usage(self) -> float:
    """Get the KV cache usage as a fraction between 0 and 1.

    Lookup order, all on the nested ``self.engine.engine``:

    1. ``cache_config.cache_usage`` when that attribute exists.
    2. ``1 - free/total`` from the scheduler's GPU block counters.

    When neither source is available a warning is logged and a *random*
    placeholder in ``[0, 0.5)`` is returned. Any exception results in ``0.0``.

    Returns:
        float: Cache usage fraction (0.0 = empty, 1.0 = full).
    """
    try:
        outer = self.engine
        if hasattr(outer, "engine") and hasattr(outer.engine, "cache_config"):
            llm_engine = outer.engine
            cache_cfg = llm_engine.cache_config
            if hasattr(cache_cfg, "cache_usage"):
                return cache_cfg.cache_usage
            if hasattr(llm_engine, "scheduler"):
                sched = llm_engine.scheduler
                has_block_counters = hasattr(
                    sched, "get_num_free_gpu_blocks"
                ) and hasattr(sched, "get_num_total_gpu_blocks")
                if has_block_counters:
                    free = sched.get_num_free_gpu_blocks()
                    total = sched.get_num_total_gpu_blocks()
                    if total > 0:
                        return 1.0 - (free / total)

        # NOTE(review): placeholder - real metrics are not wired up on this
        # path, so a value in [0, 0.5) is fabricated to simulate partial usage.
        torchrl_logger.warning(
            "Cache usage metrics not available, returning random value"
        )
        simulated = random.random() * 0.5
        return simulated
    except Exception as e:
        torchrl_logger.warning(f"Error getting cache usage: {e}")
        return 0.0
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
|
|
497
|
+
"""Get the number of GPUs per replica for the given engine args."""
|
|
498
|
+
return (
|
|
499
|
+
engine_args.tensor_parallel_size
|
|
500
|
+
* getattr(engine_args, "data_parallel_size", 1) # Default to 1 if not present
|
|
501
|
+
* getattr(
|
|
502
|
+
engine_args, "pipeline_parallel_size", 1
|
|
503
|
+
) # Default to 1 if not present
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class AsyncVLLM(RLvLLMEngine):
|
|
511
|
+
"""A service that manages multiple async vLLM engine actors for distributed inference.
|
|
512
|
+
|
|
513
|
+
This is the main entry point for async vLLM inference in TorchRL. It manages multiple
|
|
514
|
+
vLLM engine replicas running as Ray actors, providing load balancing, weight updates,
|
|
515
|
+
and a unified interface for text generation.
|
|
516
|
+
|
|
517
|
+
The service automatically handles Ray actor lifecycle management, GPU allocation through
|
|
518
|
+
placement groups, and provides both synchronous and asynchronous generation interfaces
|
|
519
|
+
that are compatible with the standard vLLM API.
|
|
520
|
+
|
|
521
|
+
Args:
|
|
522
|
+
engine_args (AsyncEngineArgs): Configuration for the vLLM engines.
|
|
523
|
+
num_replicas (int, optional): Number of engine replicas to create. Defaults to 1.
|
|
524
|
+
actor_class (optional): Custom Ray actor class. Defaults to the internal actor implementation.
|
|
525
|
+
enable_prefix_caching (bool, optional): Whether to enable prefix caching. Defaults to False.
|
|
526
|
+
|
|
527
|
+
.. warning::
|
|
528
|
+
enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
|
|
529
|
+
Set it to True if prompt log probs are not needed.
|
|
530
|
+
See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
|
|
531
|
+
|
|
532
|
+
Example:
|
|
533
|
+
>>> from torchrl.modules.llm import AsyncVLLM
|
|
534
|
+
>>> from vllm import SamplingParams
|
|
535
|
+
>>>
|
|
536
|
+
>>> # Simple usage - single GPU, single replica
|
|
537
|
+
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
|
|
538
|
+
>>>
|
|
539
|
+
>>> # Advanced usage - multi-GPU tensor parallel with multiple replicas
|
|
540
|
+
>>> service = AsyncVLLM.from_pretrained(
|
|
541
|
+
... "Qwen/Qwen2.5-7B",
|
|
542
|
+
... num_devices=2, # Use 2 GPUs for tensor parallelism
|
|
543
|
+
... num_replicas=2, # Create 2 replicas for higher throughput
|
|
544
|
+
... max_model_len=4096
|
|
545
|
+
... )
|
|
546
|
+
>>>
|
|
547
|
+
>>> # Generate text
|
|
548
|
+
>>> sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
|
|
549
|
+
>>> result = service.generate("Hello, world!", sampling_params)
|
|
550
|
+
>>> print(result.outputs[0].text)
|
|
551
|
+
>>>
|
|
552
|
+
>>> # Alternative: using AsyncEngineArgs directly for advanced configuration
|
|
553
|
+
>>> from vllm import AsyncEngineArgs
|
|
554
|
+
>>> engine_args = AsyncEngineArgs(
|
|
555
|
+
... model="Qwen/Qwen2.5-3B",
|
|
556
|
+
... tensor_parallel_size=2
|
|
557
|
+
... )
|
|
558
|
+
>>> service = AsyncVLLM.launch(engine_args, num_replicas=2)
|
|
559
|
+
|
|
560
|
+
.. note::
|
|
561
|
+
**Architecture and Design**
|
|
562
|
+
|
|
563
|
+
The AsyncVLLM service implements a distributed inference architecture with the following key components:
|
|
564
|
+
|
|
565
|
+
1. **Ray Actor Management**: Each replica runs as a separate Ray actor with dedicated GPU resources.
|
|
566
|
+
The service creates a placement group to ensure optimal GPU allocation and co-location of
|
|
567
|
+
tensor-parallel workers on the same node when possible.
|
|
568
|
+
|
|
569
|
+
2. **Load Balancing**: Generation requests are distributed across replicas using random selection
|
|
570
|
+
by default, or can target specific replicas using the `actor_index` parameter.
|
|
571
|
+
|
|
572
|
+
3. **Weight Synchronization**: The service supports weight updates across all replicas through
|
|
573
|
+
NCCL communication groups, enabling integration with distributed training workflows.
|
|
574
|
+
|
|
575
|
+
4. **Resource Management**: Automatic GPU allocation and cleanup through Ray placement groups,
|
|
576
|
+
with proper shutdown procedures to prevent resource leaks.
|
|
577
|
+
|
|
578
|
+
5. **API Compatibility**: Provides the same interface as vLLM's synchronous `LLM.generate()`
|
|
579
|
+
method, making it a drop-in replacement for async workloads.
|
|
580
|
+
|
|
581
|
+
**Ray Integration**
|
|
582
|
+
|
|
583
|
+
The service leverages Ray's actor model for distributed execution. Each replica is an independent
|
|
584
|
+
Ray actor that can be scheduled on different nodes. The service handles actor lifecycle,
|
|
585
|
+
monitors readiness, and provides centralized access to all replicas.
|
|
586
|
+
|
|
587
|
+
**Performance Considerations**
|
|
588
|
+
|
|
589
|
+
- Prefix caching is disabled by default (see the warning above); pass ``enable_prefix_caching=True`` for better performance with repeated prompts when prompt log probs are not needed
|
|
590
|
+
- Tensor parallelism is supported for large models that don't fit on single GPUs
|
|
591
|
+
- Multiple replicas allow concurrent processing of different requests
|
|
592
|
+
- Native vLLM batching is used within each replica for optimal throughput
|
|
593
|
+
|
|
594
|
+
**Error Handling**
|
|
595
|
+
|
|
596
|
+
The service includes timeout support, graceful shutdown procedures, and best-effort
|
|
597
|
+
request cleanup on failures. Ray's fault tolerance mechanisms provide additional
|
|
598
|
+
resilience for long-running inference workloads.
|
|
599
|
+
"""
|
|
600
|
+
|
|
601
|
+
def __init__(
    self,
    engine_args: AsyncEngineArgs,
    num_replicas: int = 1,
    actor_class=None,
    enable_prefix_caching: bool = False,
):
    """Store the configuration; no actors are started here.

    Args:
        engine_args (AsyncEngineArgs): Configuration for the vLLM engines. Its
            ``enable_prefix_caching`` attribute is overwritten in place.
        num_replicas (int, optional): Number of engine replicas. Defaults to 1.
        actor_class (optional): Custom Ray actor class. When ``None``, the
            internal engine class is wrapped with ``ray.remote``.
        enable_prefix_caching (bool, optional): Value written to
            ``engine_args.enable_prefix_caching``. Defaults to False.

    Raises:
        ImportError: If vLLM is not installed.
    """
    if not _has_vllm:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    # Apply the requested prefix-caching setting (off unless asked for) to the
    # caller's engine args before keeping a reference to them.
    engine_args.enable_prefix_caching = enable_prefix_caching
    self.engine_args = engine_args
    self.num_replicas = num_replicas

    if actor_class is not None:
        self.actor_class = actor_class
    else:
        # Ray is imported lazily, only when the default actor class is needed.
        ray = _get_ray()
        self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)

    self.actors: list = []
    self._launched = False
    # Unique suffix to avoid name collisions
    self._service_id = uuid.uuid4().hex[:8]
    self._placement_group = None
    self._load_balancer = None
|
|
631
|
+
|
|
632
|
+
def _launch(self):
    """Create one placement group and one engine actor per replica.

    For every replica a dedicated placement group (one GPU+CPU bundle per
    required GPU, ``PACK`` strategy) is created and awaited, then the actor is
    scheduled on bundle 0 of that group. After all actors are created the
    method blocks until every actor answers ``ready``. Calling it again after
    a successful launch only logs a warning.

    Raises:
        Exception: Re-raised from ``ray.get`` when placement groups or actors
            do not become ready in time.
    """
    if self._launched:
        torchrl_logger.warning("AsyncVLLMEngineService already launched")
        return

    # Local imports to avoid global Ray dependency
    ray = _get_ray()
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    torchrl_logger.info(f"Launching {self.num_replicas} async vLLM engine actors...")

    # One placement group per replica to avoid conflicts
    self._placement_groups = []

    # Replicas are created one after another
    for replica_idx in range(self.num_replicas):
        ordinal = replica_idx + 1
        torchrl_logger.info(
            f"Creating async actor replica {ordinal}/{self.num_replicas} ..."
        )

        gpu_count = _gpus_per_replica(self.engine_args)
        bundles = [{"GPU": 1.0, "CPU": 1.0} for _ in range(gpu_count)]
        torchrl_logger.info(
            f"Creating placement group for replica {ordinal} with {len(bundles)} bundles"
        )

        pg_name = f"vllm-replica-{self._service_id}-{replica_idx}"
        group = placement_group(bundles, strategy="PACK", name=pg_name)
        self._placement_groups.append(group)
        torchrl_logger.info(f"Placement group {pg_name} created: {group}")

        # Block until the resources of this group are reserved
        ray.get(group.ready(), timeout=180)
        torchrl_logger.info(f"Placement group {pg_name} ready")

        # Bundle indices are only passed on for multi-GPU replicas
        bundle_indices = list(range(gpu_count)) if gpu_count > 1 else None

        # Always use the first bundle since each replica has its own placement group
        strategy = PlacementGroupSchedulingStrategy(
            placement_group=group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=0,
        )

        configured_actor = self.actor_class.options(
            name=f"async-vllm-replica-{self._service_id}-{replica_idx}",
            namespace="torchrl_vllm",
            scheduling_strategy=strategy,
            num_gpus=0,
            num_cpus=0,
        )
        handle = configured_actor.remote(
            engine_args=self.engine_args,
            bundle_indices=bundle_indices,
            enable_prefix_caching=self.engine_args.enable_prefix_caching,
        )
        self.actors.append(handle)

    torchrl_logger.info("Waiting for actors to be ready")
    # Wait for every created actor to report readiness
    pending_ready = [handle.ready.remote() for handle in self.actors]
    try:
        # Bounded by the module-level TIMEOUT_SECONDS
        ray.get(pending_ready, timeout=TIMEOUT_SECONDS)
        torchrl_logger.info("✅ Actors are ready")
    except Exception as e:
        torchrl_logger.error(
            f"❌ Failed to initialize actors within {TIMEOUT_SECONDS} seconds: {e}. You can increase the timeout by setting the TORCHRL_VLLM_TIMEOUT_SECONDS environment variable."
        )
        raise

    # Store the first placement group for backward compatibility
    if self._placement_groups:
        self._placement_group = self._placement_groups[0]
    else:
        self._placement_group = None

    self._launched = True
    torchrl_logger.info(
        f"✅ Successfully launched {len(self.actors)} async vLLM engine actors"
    )
|
|
720
|
+
|
|
721
|
+
@classmethod
def launch(
    cls,
    engine_args: AsyncEngineArgs,
    num_replicas: int = 1,
) -> AsyncVLLM:
    """Build a service, start its actors and attach a default load balancer.

    Args:
        engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
        num_replicas (int): Number of actor replicas to create.

    Returns:
        AsyncVLLM: The launched service.
    """
    instance = cls(engine_args, num_replicas)
    instance._launch()
    # A load balancer with default settings is created right away so that
    # multi-replica generation works without further setup.
    instance.create_load_balancer()
    return instance
|
|
741
|
+
|
|
742
|
+
@classmethod
def from_pretrained(
    cls,
    model_name: str,
    num_devices: int | None = None,
    num_replicas: int = 1,
    verbose: bool = True,
    compile: bool = True,
    enable_fp32_output: bool = False,
    **kwargs,
) -> AsyncVLLM:
    """Create an AsyncVLLM instance from a pretrained model.

    Convenience constructor: every argument is forwarded unchanged to
    ``make_async_vllm_engine``, whose result is returned.

    Args:
        model_name (str): The model name to pass to vLLM.
        num_devices (int, optional): Number of devices to use, per replica.
        num_replicas (int): Number of engine replicas to create.
        verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True.
        compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
        enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
        **kwargs: Additional arguments passed to AsyncEngineArgs.

    Returns:
        AsyncVLLM: The launched async vLLM service.

    Example:
        >>> # Simple usage with defaults
        >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
        >>>
        >>> # Multi-GPU tensor parallel with multiple replicas
        >>> service = AsyncVLLM.from_pretrained(
        ...     "Qwen/Qwen2.5-7B",
        ...     num_devices=2,
        ...     num_replicas=2,
        ...     max_model_len=4096
        ... )
        >>>
        >>> # Generate text
        >>> from vllm import SamplingParams
        >>> result = service.generate("Hello, world!", SamplingParams(max_tokens=50))
        >>>
        >>> # Enable FP32 output for better numerical stability
        >>> service = AsyncVLLM.from_pretrained(
        ...     "Qwen/Qwen2.5-3B",
        ...     enable_fp32_output=True
        ... )
    """
    factory_options = {
        "model_name": model_name,
        "num_devices": num_devices,
        "num_replicas": num_replicas,
        "verbose": verbose,
        "compile": compile,
        "enable_fp32_output": enable_fp32_output,
    }
    # Extra keyword arguments cannot collide with the named ones above.
    factory_options.update(kwargs)
    return make_async_vllm_engine(**factory_options)
|
|
801
|
+
|
|
802
|
+
def _is_batch(
|
|
803
|
+
self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
|
|
804
|
+
) -> bool:
|
|
805
|
+
"""Check if the input represents a batch of prompts.
|
|
806
|
+
|
|
807
|
+
Args:
|
|
808
|
+
prompts: Input prompts that can be string, TokensPrompt, or list of these
|
|
809
|
+
prompt_token_ids: Alternative token IDs input
|
|
810
|
+
|
|
811
|
+
Returns:
|
|
812
|
+
bool: True if this represents multiple prompts, False for single prompt
|
|
813
|
+
"""
|
|
814
|
+
# If prompts is a list, we need to determine if it's a batch or a single prompt
|
|
815
|
+
if isinstance(prompts, list):
|
|
816
|
+
# Empty list is not a batch
|
|
817
|
+
if len(prompts) == 0:
|
|
818
|
+
return False
|
|
819
|
+
|
|
820
|
+
# If all elements are integers, it's a single prompt represented as token IDs
|
|
821
|
+
# We trust that if one is an int, then all are ints.
|
|
822
|
+
if any(isinstance(item, int) for item in prompts):
|
|
823
|
+
return False
|
|
824
|
+
|
|
825
|
+
# If it contains strings, TokensPrompt objects, or other non-integer types,
|
|
826
|
+
# it's a batch of prompts
|
|
827
|
+
return True
|
|
828
|
+
|
|
829
|
+
# If prompt_token_ids is provided and is a list of lists, it's a batch
|
|
830
|
+
if prompt_token_ids is not None and isinstance(prompt_token_ids, list):
|
|
831
|
+
if len(prompt_token_ids) > 0 and isinstance(prompt_token_ids[0], list):
|
|
832
|
+
return True
|
|
833
|
+
|
|
834
|
+
return False
|
|
835
|
+
|
|
836
|
+
def _iterate(
|
|
837
|
+
self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
|
|
838
|
+
):
|
|
839
|
+
"""Iterate over individual prompts in a batch.
|
|
840
|
+
|
|
841
|
+
Args:
|
|
842
|
+
prompts: Input prompts that can be string, TokensPrompt, or list of these
|
|
843
|
+
prompt_token_ids: Alternative token IDs input
|
|
844
|
+
|
|
845
|
+
Yields:
|
|
846
|
+
tuple: (individual_prompt, individual_prompt_token_ids) for each item
|
|
847
|
+
"""
|
|
848
|
+
if isinstance(prompts, list):
|
|
849
|
+
# Check if this is actually a single prompt represented as token IDs
|
|
850
|
+
if all(isinstance(item, int) for item in prompts):
|
|
851
|
+
# This is a single prompt as token IDs, not a batch
|
|
852
|
+
yield prompts, prompt_token_ids
|
|
853
|
+
return
|
|
854
|
+
|
|
855
|
+
# Handle list of prompts (actual batch)
|
|
856
|
+
if prompt_token_ids is None:
|
|
857
|
+
for prompt in prompts:
|
|
858
|
+
yield prompt, None
|
|
859
|
+
elif (
|
|
860
|
+
isinstance(prompt_token_ids, list)
|
|
861
|
+
and len(prompt_token_ids) > 0
|
|
862
|
+
and isinstance(prompt_token_ids[0], list)
|
|
863
|
+
):
|
|
864
|
+
# Both prompts and prompt_token_ids are lists
|
|
865
|
+
for prompt, token_ids in zip(prompts, prompt_token_ids):
|
|
866
|
+
yield prompt, token_ids
|
|
867
|
+
else:
|
|
868
|
+
# prompts is list, but prompt_token_ids is single list - replicate it
|
|
869
|
+
for prompt in prompts:
|
|
870
|
+
yield prompt, prompt_token_ids
|
|
871
|
+
else:
|
|
872
|
+
# Single prompt case
|
|
873
|
+
if (
|
|
874
|
+
prompt_token_ids is not None
|
|
875
|
+
and isinstance(prompt_token_ids, list)
|
|
876
|
+
and len(prompt_token_ids) > 0
|
|
877
|
+
and isinstance(prompt_token_ids[0], list)
|
|
878
|
+
):
|
|
879
|
+
# Single prompt but multiple token_ids - replicate prompt
|
|
880
|
+
for token_ids in prompt_token_ids:
|
|
881
|
+
yield prompts, token_ids
|
|
882
|
+
else:
|
|
883
|
+
# Single prompt, single (or no) token_ids
|
|
884
|
+
yield prompts, prompt_token_ids
|
|
885
|
+
|
|
886
|
+
def _generate_impl(
|
|
887
|
+
self,
|
|
888
|
+
prompt: Any,
|
|
889
|
+
sampling_params: SamplingParams | None = None,
|
|
890
|
+
*,
|
|
891
|
+
prompt_token_ids: list[int] | None = None,
|
|
892
|
+
use_tqdm: bool = True,
|
|
893
|
+
lora_request: Any = None,
|
|
894
|
+
prompt_adapter_request: Any = None,
|
|
895
|
+
guided_options_request: Any = None,
|
|
896
|
+
timeout_seconds: float | None = None,
|
|
897
|
+
actor_index: int | None = None,
|
|
898
|
+
):
|
|
899
|
+
"""Generate text for a single prompt and return a Ray future.
|
|
900
|
+
|
|
901
|
+
This is the internal implementation that returns a future instead of the result.
|
|
902
|
+
Used for batched generation to enable parallel execution.
|
|
903
|
+
|
|
904
|
+
Args:
|
|
905
|
+
prompt: Single prompt (string, TokensPrompt, etc.)
|
|
906
|
+
sampling_params: SamplingParams object for controlling generation behavior
|
|
907
|
+
prompt_token_ids: Token IDs for a single prompt
|
|
908
|
+
use_tqdm: Whether to show progress bar (not used in async engine)
|
|
909
|
+
lora_request: LoRA request for adapter-based generation
|
|
910
|
+
prompt_adapter_request: Prompt adapter request
|
|
911
|
+
guided_options_request: Guided decoding options
|
|
912
|
+
timeout_seconds: Timeout for generation in seconds
|
|
913
|
+
actor_index: Specific actor to use (random if None)
|
|
914
|
+
|
|
915
|
+
Returns:
|
|
916
|
+
Ray ObjectRef: Future that will resolve to RequestOutput
|
|
917
|
+
"""
|
|
918
|
+
if actor_index is None:
|
|
919
|
+
if len(self.actors) == 1:
|
|
920
|
+
actor = self.actors[0]
|
|
921
|
+
else:
|
|
922
|
+
if self._load_balancer is None:
|
|
923
|
+
raise RuntimeError(
|
|
924
|
+
"LoadBalancer is not created. Create a LoadBalancer using AsyncVLLM.create_load_balancer before calling generate."
|
|
925
|
+
)
|
|
926
|
+
# Extract single prompt for prefix-aware routing
|
|
927
|
+
single_prompt = self._extract_single_prompt_for_routing(
|
|
928
|
+
prompt, prompt_token_ids
|
|
929
|
+
)
|
|
930
|
+
actor_index = self._load_balancer.select_actor(prompt=single_prompt)
|
|
931
|
+
actor = self.actors[actor_index]
|
|
932
|
+
else:
|
|
933
|
+
actor = self.actors[actor_index]
|
|
934
|
+
|
|
935
|
+
return actor.generate.remote(
|
|
936
|
+
prompt,
|
|
937
|
+
sampling_params,
|
|
938
|
+
prompt_token_ids=prompt_token_ids,
|
|
939
|
+
use_tqdm=use_tqdm,
|
|
940
|
+
lora_request=lora_request,
|
|
941
|
+
prompt_adapter_request=prompt_adapter_request,
|
|
942
|
+
guided_options_request=guided_options_request,
|
|
943
|
+
timeout_seconds=timeout_seconds,
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
def generate(
    self,
    prompts: Any = None,
    sampling_params: SamplingParams | None = None,
    *,
    prompt_token_ids: list[int] | list[list[int]] | None = None,
    use_tqdm: bool = True,
    lora_request: Any = None,
    prompt_adapter_request: Any = None,
    guided_options_request: Any = None,
    timeout_seconds: float | None = None,
    actor_index: int | None = None,
) -> RequestOutput | list[RequestOutput]:
    """Generate text using one of the actors with vLLM.LLM.generate interface.

    Mirrors vLLM.LLM.generate so sync and async engines can be swapped. Batched
    input is split into individual requests that are all dispatched before any
    result is collected; a single prompt is dispatched as one request.

    `generate` is a blocking method, so it will wait for the generation to complete.

    Args:
        prompts (String, TokensPrompt, or list of these): Input prompts for generation.
        sampling_params (SamplingParams): SamplingParams object for controlling generation behavior.
        prompt_token_ids (list[int] | list[list[int]]): Alternative to prompts - token IDs for generation.
        use_tqdm (bool): Whether to show progress bar (not used in async engine).
        lora_request (Any): LoRA request for adapter-based generation.
        prompt_adapter_request (Any): Prompt adapter request.
        guided_options_request (Any): Guided decoding options.
        timeout_seconds (float | None): Timeout for generation in seconds.
        actor_index (int | None): Specific actor to use; when None the actor is
            chosen by ``_generate_impl``.

    Returns:
        RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
    """
    ray = _get_ray()

    # Options that are identical for every dispatched request.
    shared_options = {
        "use_tqdm": use_tqdm,
        "lora_request": lora_request,
        "prompt_adapter_request": prompt_adapter_request,
        "guided_options_request": guided_options_request,
        "timeout_seconds": timeout_seconds,
        "actor_index": actor_index,
    }

    if not self._is_batch(prompts, prompt_token_ids):
        # Single prompt: one future, resolved directly.
        pending = self._generate_impl(
            prompts,
            sampling_params,
            prompt_token_ids=prompt_token_ids,
            **shared_options,
        )
        return ray.get(pending)

    # Batch: unbind into individual requests, then collect all results at once.
    pending_batch = []
    for single_prompt, single_token_ids in self._iterate(prompts, prompt_token_ids):
        pending_batch.append(
            self._generate_impl(
                single_prompt,
                sampling_params,
                prompt_token_ids=single_token_ids,
                **shared_options,
            )
        )
    return ray.get(pending_batch)
|
|
1019
|
+
|
|
1020
|
+
def get_random_actor_index(self) -> int:
    """Return a uniformly drawn index into ``self.actors``."""
    last_index = len(self.actors) - 1
    return random.randint(0, last_index)
|
|
1023
|
+
|
|
1024
|
+
def _init_weight_update_group_internal(self, master_address: str, master_port: str):
    """Ask every actor to join the NCCL weight update group.

    Actor ``i`` receives the rank offset ``1 + i * gpus_per_replica`` and the
    group world size is ``num_replicas * gpus_per_replica + 1``.

    Args:
        master_address (str): Master address for distributed training.
        master_port (str): Master port for distributed training.

    Returns:
        list: Ray futures for initialization calls.
    """
    gpus_per_replica = _gpus_per_replica(self.engine_args)
    weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
    torchrl_logger.info(
        f"AsyncVLLMEngineService requests weight update group for {self.num_replicas} actors "
        f"with {gpus_per_replica} GPUs per replica and {weight_sync_world_size} world size"
    )

    from vllm import envs

    pending = []
    for actor_idx, actor in enumerate(self.actors):
        # The names used in the ``{name=}`` log format below are part of the
        # emitted message, so they are kept as-is.
        rank_offset = actor_idx * gpus_per_replica + 1
        rpc_handle = (
            actor.collective_rpc_v1
            if envs and envs.VLLM_USE_V1
            else actor.collective_rpc_v0
        )
        init_args = (
            master_address,
            master_port,
            rank_offset,
            weight_sync_world_size,
        )
        pending.append(rpc_handle.remote("init_weight_update_group", args=init_args))
        torchrl_logger.info(
            f"AsyncVLLMEngineService args: {master_address=}, {master_port=}, "
            f"{rank_offset=}, {weight_sync_world_size=}"
        )
        torchrl_logger.info(
            f"AsyncVLLMEngineService requests weight update group for actor {actor_idx} "
            f"with rank_offset {rank_offset}"
        )
    return pending
|
|
1071
|
+
|
|
1072
|
+
def collective_rpc(
    self,
    method: str,
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict | None = None,
) -> list[Any]:
    """Dispatch the same RPC to every actor without waiting for results.

    Args:
        method (str): Method name to call.
        timeout (float | None): Timeout for the RPC call.
        args (tuple): Arguments to pass to the method.
        kwargs (dict | None): Keyword arguments to pass to the method.

    Returns:
        list[Any]: Ray futures for all RPC calls, in ``self.actors`` order.
    """
    from vllm import envs

    def _rpc_handle(actor):
        # Pick the actor entry point matching the vLLM engine generation.
        if envs and envs.VLLM_USE_V1:
            return actor.collective_rpc_v1
        return actor.collective_rpc_v0

    return [
        _rpc_handle(actor).remote(method, timeout, args, kwargs)
        for actor in self.actors
    ]
|
|
1100
|
+
|
|
1101
|
+
def shutdown(self):
    """Shutdown all actors and clean up resources.

    Every teardown step is best-effort: a failure while killing an actor or
    removing a placement group is logged as a warning and the remaining steps
    still run, so the service always ends up marked as not launched.
    """
    torchrl_logger.info(
        f"Shutting down {len(self.actors)} async vLLM engine actors..."
    )

    ray = _get_ray()
    from ray.util.placement_group import remove_placement_group

    # Kill all actors
    for i, actor in enumerate(self.actors):
        try:
            ray.kill(actor)
            torchrl_logger.info(f"Shutdown async actor {i + 1}/{len(self.actors)}")
        except Exception as e:
            torchrl_logger.warning(f"Error shutting down async actor {i + 1}: {e}")

    # Clear the actors list
    self.actors.clear()

    # Remove placement groups if any
    if hasattr(self, "_placement_groups") and self._placement_groups:
        for i, pg in enumerate(self._placement_groups):
            try:
                remove_placement_group(pg)
                torchrl_logger.info(
                    f"Removed placement group {i + 1}/{len(self._placement_groups)}"
                )
            except Exception as e:
                torchrl_logger.warning(
                    f"Error removing placement group {i + 1}: {e}"
                )
        self._placement_groups = []

    # Remove legacy single placement group if any. This is guarded like the
    # steps above so that a failure here cannot leave the service flagged as
    # launched after its actors have already been killed.
    if self._placement_group is not None:
        try:
            remove_placement_group(self._placement_group)
        except Exception as e:
            torchrl_logger.warning(f"Error removing legacy placement group: {e}")
    self._placement_group = None
    self._launched = False
    torchrl_logger.info("AsyncVLLMEngineService shutdown complete")
|
|
1141
|
+
|
|
1142
|
+
# RLvLLMEngine interface implementation
|
|
1143
|
+
def get_tp_size(self) -> int:
    """Report the tensor-parallel size stored on ``self.engine_args``."""
    engine_args = self.engine_args
    return engine_args.tensor_parallel_size
|
|
1146
|
+
|
|
1147
|
+
def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
    """Return parameter metadata for the served model.

    Note: extraction is not implemented. A warning is logged and an empty
    dict is returned; metadata is expected to be supplied externally during
    weight updates.
    """
    # TODO: Implement metadata extraction from loaded model
    # This would require accessing the model from one of the actors
    metadata: dict[str, tuple[torch.dtype, torch.Size]] = {}
    torchrl_logger.warning(
        "AsyncVLLM.get_model_metadata() not yet implemented - returning empty dict"
    )
    return metadata
|
|
1159
|
+
|
|
1160
|
+
def get_master_address(self) -> str:
    """Return the master address used for weight synchronization."""
    default_address = "localhost"  # Default for now
    return default_address
|
|
1163
|
+
|
|
1164
|
+
def get_master_port(self) -> int:
    """Return the master port used for weight synchronization.

    The port is resolved once and stored on ``self._cached_master_port`` so
    that every later call reports the same value.
    """
    # Cache the port like V1 does to ensure consistency
    if hasattr(self, "_cached_master_port"):
        return self._cached_master_port

    if not _has_vllm:
        self._cached_master_port = 29500  # Default port
        return self._cached_master_port

    try:
        from vllm.utils import get_open_port

        self._cached_master_port = get_open_port()
    except ImportError:
        self._cached_master_port = 29500  # Default port if import fails
    return self._cached_master_port
|
|
1178
|
+
|
|
1179
|
+
def init_weight_update_group(
    self,
    master_address: str,
    master_port: int | str,
) -> list[Any]:
    """Forward the request to init NCCL weight update group to all actors.

    The external trainer should be rank 0, and vLLM workers will be ranks 1+:
    actor ``i`` is handed the rank offset ``1 + i * gpus_per_replica``.

    Args:
        master_address: Master address for NCCL communication.
        master_port: Master port for NCCL communication; forwarded to the
            actors as a string.

    Returns:
        List of Ray futures for the initialization calls.

    Raises:
        RuntimeError: If the service has not been launched.

    Note:
        The caller must wait on the returned futures (ray.get(refs)) to ensure
        all workers have completed initialization before sending weights.
    """
    if not self._launched:
        raise RuntimeError(
            "AsyncVLLM service must be launched before initializing weight update group"
        )

    per_replica = _gpus_per_replica(self.engine_args)
    world_size = self.num_replicas * per_replica + 1

    torchrl_logger.info(
        f"Initializing weight update group for {self.num_replicas} replicas "
        f"with {per_replica} GPUs each (world_size={world_size})"
    )

    from vllm import envs

    pending = []
    for actor_idx, actor in enumerate(self.actors):
        offset = actor_idx * per_replica + 1
        rpc_handle = (
            actor.collective_rpc_v1
            if envs and envs.VLLM_USE_V1
            else actor.collective_rpc_v0
        )
        init_args = (master_address, str(master_port), offset, world_size)
        pending.append(rpc_handle.remote("init_weight_update_group", args=init_args))
        torchrl_logger.info(
            f"Requested init for actor {actor_idx} with rank_offset {offset}"
        )

    return pending
|
|
1238
|
+
|
|
1239
|
+
def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
    """Push new model weights to all replicas using NCCL broadcast.

    Args:
        weights: Iterator yielding (parameter_name, tensor) tuples

    Raises:
        RuntimeError: If the service has not been launched.
    """
    if not self._launched:
        raise RuntimeError(
            "AsyncVLLM service must be launched before updating weights"
        )

    # Materialize the iterator so it can be counted and handed on as a dict
    weights_dict = dict(weights)

    if weights_dict:
        torchrl_logger.info(
            f"Updating {len(weights_dict)} parameters across {len(self.actors)} replicas using NCCL broadcast"
        )
        self._update_weights_with_nccl_broadcast_simple(weights_dict)
        torchrl_logger.info("AsyncVLLM NCCL weight update completed")
    else:
        torchrl_logger.warning("No weights provided for update")
|
|
1264
|
+
|
|
1265
|
+
def _update_weights_with_nccl_broadcast_simple(
|
|
1266
|
+
self, weights_dict: dict[str, torch.Tensor]
|
|
1267
|
+
) -> None:
|
|
1268
|
+
"""Update weights using simple NCCL broadcast like V1.
|
|
1269
|
+
|
|
1270
|
+
This approach follows the V1 pattern:
|
|
1271
|
+
1. Training process (master) broadcasts as rank 0
|
|
1272
|
+
2. All vLLM workers receive as ranks 1, 2, 3...
|
|
1273
|
+
3. Simple and reliable like the working V1 implementation
|
|
1274
|
+
|
|
1275
|
+
Args:
|
|
1276
|
+
weights_dict: Dictionary of parameter names to weight tensors
|
|
1277
|
+
"""
|
|
1278
|
+
if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None:
|
|
1279
|
+
raise RuntimeError(
|
|
1280
|
+
"NCCL master group not initialized. This is a bug in the setup process."
|
|
1281
|
+
)
|
|
1282
|
+
|
|
1283
|
+
t0 = time.time()
|
|
1284
|
+
|
|
1285
|
+
# Move all weights to cuda:0 (matching NCCL communicator device)
|
|
1286
|
+
gpu_weights = {}
|
|
1287
|
+
for name, weight in weights_dict.items():
|
|
1288
|
+
# Ensure weight is on cuda:0 (matching NCCL communicator)
|
|
1289
|
+
if weight.device != torch.device("cuda:0"):
|
|
1290
|
+
gpu_weights[name] = weight.to("cuda:0", non_blocking=True)
|
|
1291
|
+
else:
|
|
1292
|
+
gpu_weights[name] = weight
|
|
1293
|
+
|
|
1294
|
+
# Use periodic-mono pattern: individual weight updates with immediate RPC->NCCL
|
|
1295
|
+
torchrl_logger.info(
|
|
1296
|
+
f"Updating {len(gpu_weights)} weights using periodic-mono pattern..."
|
|
1297
|
+
)
|
|
1298
|
+
|
|
1299
|
+
updated_weights = 0
|
|
1300
|
+
ray = _get_ray()
|
|
1301
|
+
with torch.cuda.device(0): # Ensure we're on the correct CUDA device
|
|
1302
|
+
for name, weight in gpu_weights.items():
|
|
1303
|
+
# Convert dtype to string name (like periodic-mono)
|
|
1304
|
+
dtype_name = str(weight.dtype).split(".")[
|
|
1305
|
+
-1
|
|
1306
|
+
] # "torch.bfloat16" -> "bfloat16"
|
|
1307
|
+
|
|
1308
|
+
# Step 1: Send RPC to workers for this weight
|
|
1309
|
+
futures = self.collective_rpc(
|
|
1310
|
+
"update_weight", args=(name, dtype_name, tuple(weight.shape))
|
|
1311
|
+
)
|
|
1312
|
+
|
|
1313
|
+
# Step 2: Immediately broadcast this weight (like periodic-mono)
|
|
1314
|
+
self._nccl_master_group.broadcast(
|
|
1315
|
+
weight, src=0, stream=torch.cuda.current_stream()
|
|
1316
|
+
)
|
|
1317
|
+
|
|
1318
|
+
# Step 3: Wait for workers to complete this weight
|
|
1319
|
+
ray.get(futures)
|
|
1320
|
+
updated_weights += 1
|
|
1321
|
+
|
|
1322
|
+
torch.cuda.synchronize()
|
|
1323
|
+
t2 = time.time()
|
|
1324
|
+
torchrl_logger.info(
|
|
1325
|
+
f"Successfully updated {updated_weights}/{len(gpu_weights)} weights in {t2 - t0:.3f}s"
|
|
1326
|
+
)
|
|
1327
|
+
|
|
1328
|
+
def _setup_nccl_master_group(self) -> None:
    """Create the rank-0 NCCL group and store it on ``self._nccl_master_group``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # World size must match what the workers use: all replica GPUs plus rank 0
    weight_sync_world_size = (
        self.num_replicas * _gpus_per_replica(self.engine_args) + 1
    )

    master_address = self.get_master_address()
    master_port = self.get_master_port()

    torchrl_logger.info(
        f"Setting up NCCL master group: rank=0, world_size={weight_sync_world_size}, "
        f"address={master_address}:{master_port}"
    )

    # Ensure CUDA is available and initialized
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        raise RuntimeError("CUDA not available for NCCL communication")

    # Set CUDA device before initializing NCCL
    torch.cuda.set_device(0)

    # Initialize master as rank 0 in the NCCL group (use synchronous version)
    master_group = stateless_init_process_group(
        master_address=master_address,
        master_port=str(master_port),
        rank=0,  # Master is always rank 0
        world_size=weight_sync_world_size,
        device=torch.device("cuda:0"),
    )
    self._nccl_master_group = master_group

    torchrl_logger.info("NCCL master group initialized successfully")
|
|
1359
|
+
|
|
1360
|
+
def get_num_unfinished_requests(
    self, actor_index: int | None = None
) -> int | list[int]:
    """Query how many requests are still unfinished on one or all actors.

    Args:
        actor_index (int | None): Index of specific actor, or None for all actors.

    Returns:
        int | list[int]: Number of unfinished requests for the specified actor,
            or list of counts for all actors if actor_index is None.

    Raises:
        RuntimeError: If the service has not been launched.
        IndexError: If ``actor_index`` is outside ``[0, len(self.actors))``.
    """
    if not self._launched:
        raise RuntimeError(
            "AsyncVLLM service must be launched before getting request counts"
        )

    ray = _get_ray()
    if actor_index is None:
        # Fan out to every actor, then block once for all the counts
        pending = [
            actor.get_num_unfinished_requests.remote() for actor in self.actors
        ]
        return ray.get(pending)

    if not (0 <= actor_index < len(self.actors)):
        raise IndexError(
            f"Actor index {actor_index} out of range [0, {len(self.actors)})"
        )
    target = self.actors[actor_index]
    return ray.get(target.get_num_unfinished_requests.remote())
|
|
1392
|
+
|
|
1393
|
+
def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]:
    """Query the KV cache usage of one or all actors.

    Args:
        actor_index (int | None): Index of specific actor, or None for all actors.

    Returns:
        float | list[float]: Cache usage fraction for the specified actor,
            or list of usage fractions for all actors if actor_index is None.

    Raises:
        RuntimeError: If the service has not been launched.
        IndexError: If ``actor_index`` is outside ``[0, len(self.actors))``.
    """
    if not self._launched:
        raise RuntimeError(
            "AsyncVLLM service must be launched before getting cache usage"
        )

    ray = _get_ray()
    if actor_index is None:
        # Fan out to every actor, then block once for all the usages
        pending = [actor.get_cache_usage.remote() for actor in self.actors]
        return ray.get(pending)

    if not (0 <= actor_index < len(self.actors)):
        raise IndexError(
            f"Actor index {actor_index} out of range [0, {len(self.actors)})"
        )
    target = self.actors[actor_index]
    return ray.get(target.get_cache_usage.remote())
|
|
1421
|
+
|
|
1422
|
+
def create_load_balancer(
    self,
    strategy: Literal["requests", "kv-cache"]
    | Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
    | None = None,
    **kwargs,
) -> LoadBalancer:
    """Build a :class:`LoadBalancer` bound to this service and remember it.

    Args:
        strategy: Load balancing strategy or sequence of strategies in fallback order.
            Passed through to ``LoadBalancer``; ``None`` lets it pick its own default.
        **kwargs: Additional arguments passed to LoadBalancer constructor.

    Returns:
        LoadBalancer: Configured load balancer instance. It is also stored on
            this instance as ``self._load_balancer``.

    Raises:
        RuntimeError: If the service has not been launched.

    Examples:
        >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)

        >>> # Default strategy
        >>> lb = service.create_load_balancer()
        >>> selected_actor_index = lb.select_actor(prompt="Hello world")

        >>> # Simple single strategy
        >>> lb = service.create_load_balancer("requests")
        >>> selected_actor_index = lb.select_actor()

        >>> # Custom strategy hierarchy
        >>> lb = service.create_load_balancer(
        ...     ["prefix-aware", "kv-cache", "round-robin"],
        ...     prefix_length=16,
        ...     overload_threshold=2.0
        ... )
        >>> selected_actor_index = lb.select_actor(prompt="Hello world")
    """
    if not self._launched:
        raise RuntimeError(
            "AsyncVLLM service must be launched before creating load balancer"
        )

    self._load_balancer = LoadBalancer(self, strategy, **kwargs)
    return self._load_balancer
|
|
1468
|
+
|
|
1469
|
+
def _extract_single_prompt_for_routing(
|
|
1470
|
+
self,
|
|
1471
|
+
prompts: Any = None,
|
|
1472
|
+
prompt_token_ids: list[int] | list[list[int]] | None = None,
|
|
1473
|
+
) -> str | list[int] | None:
|
|
1474
|
+
"""Extract a single prompt for load balancer routing, if possible.
|
|
1475
|
+
|
|
1476
|
+
Args:
|
|
1477
|
+
prompts: The prompts argument passed to generate().
|
|
1478
|
+
prompt_token_ids: The prompt_token_ids argument passed to generate().
|
|
1479
|
+
|
|
1480
|
+
Returns:
|
|
1481
|
+
str | list[int] | None: Single prompt for routing, or None if multiple prompts.
|
|
1482
|
+
"""
|
|
1483
|
+
try:
|
|
1484
|
+
# Handle prompt_token_ids first (takes precedence over prompts)
|
|
1485
|
+
if prompt_token_ids is not None:
|
|
1486
|
+
if isinstance(prompt_token_ids, list):
|
|
1487
|
+
if len(prompt_token_ids) == 0:
|
|
1488
|
+
return None # Empty list
|
|
1489
|
+
elif len(prompt_token_ids) == 1:
|
|
1490
|
+
# Single prompt case - could be tokens directly or nested list
|
|
1491
|
+
if isinstance(prompt_token_ids[0], int):
|
|
1492
|
+
# Single token sequence: [token1, token2, ...]
|
|
1493
|
+
return prompt_token_ids
|
|
1494
|
+
elif isinstance(prompt_token_ids[0], list):
|
|
1495
|
+
# Nested list with single prompt: [[token1, token2, ...]]
|
|
1496
|
+
return prompt_token_ids[0]
|
|
1497
|
+
else:
|
|
1498
|
+
return None
|
|
1499
|
+
else:
|
|
1500
|
+
# Multiple prompts: [[tokens1...], [tokens2...], ...]
|
|
1501
|
+
return None
|
|
1502
|
+
else:
|
|
1503
|
+
# Not a list, invalid format
|
|
1504
|
+
return None
|
|
1505
|
+
|
|
1506
|
+
# Handle prompts argument
|
|
1507
|
+
if prompts is None:
|
|
1508
|
+
return None
|
|
1509
|
+
|
|
1510
|
+
# Import vLLM types for proper checking
|
|
1511
|
+
try:
|
|
1512
|
+
pass
|
|
1513
|
+
except ImportError:
|
|
1514
|
+
# Fallback if imports fail
|
|
1515
|
+
type(None)
|
|
1516
|
+
type(None)
|
|
1517
|
+
|
|
1518
|
+
# Single string prompt
|
|
1519
|
+
if isinstance(prompts, str):
|
|
1520
|
+
return prompts
|
|
1521
|
+
|
|
1522
|
+
# TokensPrompt object
|
|
1523
|
+
elif hasattr(prompts, "prompt_token_ids"): # TokensPrompt-like object
|
|
1524
|
+
return prompts.prompt_token_ids
|
|
1525
|
+
|
|
1526
|
+
# TextPrompt object
|
|
1527
|
+
elif hasattr(prompts, "prompt"): # TextPrompt-like object
|
|
1528
|
+
return prompts.prompt
|
|
1529
|
+
|
|
1530
|
+
# List of prompts
|
|
1531
|
+
elif isinstance(prompts, (list, tuple)):
|
|
1532
|
+
if len(prompts) == 0:
|
|
1533
|
+
return None # Empty list
|
|
1534
|
+
elif len(prompts) == 1:
|
|
1535
|
+
# Single prompt in list - recursively extract
|
|
1536
|
+
return self._extract_single_prompt_for_routing(prompts[0], None)
|
|
1537
|
+
else:
|
|
1538
|
+
# Multiple prompts - cannot do prefix routing
|
|
1539
|
+
return None
|
|
1540
|
+
|
|
1541
|
+
# Other types (shouldn't happen in normal usage)
|
|
1542
|
+
else:
|
|
1543
|
+
torchrl_logger.debug(
|
|
1544
|
+
f"Unknown prompt type for routing: {type(prompts)}"
|
|
1545
|
+
)
|
|
1546
|
+
return None
|
|
1547
|
+
|
|
1548
|
+
except Exception as e:
|
|
1549
|
+
torchrl_logger.debug(f"Error extracting single prompt for routing: {e}")
|
|
1550
|
+
return None
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
class LoadBalancer:
|
|
1554
|
+
"""Load balancer for distributing requests across AsyncVLLM actors with strategy hierarchy.
|
|
1555
|
+
|
|
1556
|
+
This class implements sophisticated load balancing with multiple strategies and intelligent
|
|
1557
|
+
fallback mechanisms. Strategies are tried in order until one succeeds, providing robust
|
|
1558
|
+
request routing even when some strategies fail.
|
|
1559
|
+
|
|
1560
|
+
Args:
|
|
1561
|
+
actors: Either a single AsyncVLLM instance or a list of Ray actors.
|
|
1562
|
+
strategy: Single strategy or sequence of strategies in fallback order.
|
|
1563
|
+
Available strategies:
|
|
1564
|
+
|
|
1565
|
+
- "prefix-aware": Route based on prompt prefix for cache locality
|
|
1566
|
+
- "requests": Select actor with fewest pending requests
|
|
1567
|
+
- "kv-cache": Select actor with lowest KV cache utilization
|
|
1568
|
+
- "round-robin": Simple round-robin distribution
|
|
1569
|
+
|
|
1570
|
+
Default: ["prefix-aware", "requests"]
|
|
1571
|
+
|
|
1572
|
+
prefix_length: Number of tokens/words to use for prefix routing (default: 8).
|
|
1573
|
+
overload_threshold: Multiplier for average load to consider actor overloaded (default: 1.5).
|
|
1574
|
+
|
|
1575
|
+
Examples:
|
|
1576
|
+
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)
|
|
1577
|
+
|
|
1578
|
+
>>> # Simple strategy
|
|
1579
|
+
>>> lb = LoadBalancer(service, "requests")
|
|
1580
|
+
>>> actor_idx = lb.select_actor()
|
|
1581
|
+
|
|
1582
|
+
>>> # Strategy hierarchy: try prefix-aware first, fall back to requests, then round-robin
|
|
1583
|
+
>>> lb = LoadBalancer(service, ["prefix-aware", "requests", "round-robin"])
|
|
1584
|
+
>>> actor_idx = lb.select_actor(prompt="Hello world") # Uses prefix routing
|
|
1585
|
+
>>> actor_idx = lb.select_actor() # Falls back to requests (no prompt)
|
|
1586
|
+
|
|
1587
|
+
>>> # Custom configuration
|
|
1588
|
+
>>> lb = LoadBalancer(
|
|
1589
|
+
... service,
|
|
1590
|
+
... ["prefix-aware", "kv-cache"],
|
|
1591
|
+
... prefix_length=16,
|
|
1592
|
+
... overload_threshold=2.0
|
|
1593
|
+
... )
|
|
1594
|
+
"""
|
|
1595
|
+
|
|
1596
|
+
def __init__(
|
|
1597
|
+
self,
|
|
1598
|
+
actors: list[Any] | AsyncVLLM,
|
|
1599
|
+
strategy: Literal["requests", "kv-cache"]
|
|
1600
|
+
| Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
|
|
1601
|
+
| None = None,
|
|
1602
|
+
prefix_length: int = 8,
|
|
1603
|
+
overload_threshold: float = 1.5,
|
|
1604
|
+
):
|
|
1605
|
+
if strategy is None:
|
|
1606
|
+
strategy = ["prefix-aware", "requests"]
|
|
1607
|
+
# Handle both AsyncVLLM instances and direct actor lists
|
|
1608
|
+
if hasattr(actors, "actors"): # AsyncVLLM instance
|
|
1609
|
+
self.actors = actors.actors
|
|
1610
|
+
self.async_vllm = actors
|
|
1611
|
+
elif isinstance(actors, list): # Direct list of actors
|
|
1612
|
+
self.actors = actors
|
|
1613
|
+
self.async_vllm = None
|
|
1614
|
+
else:
|
|
1615
|
+
raise ValueError(
|
|
1616
|
+
"actors must be either an AsyncVLLM instance or a list of actors"
|
|
1617
|
+
)
|
|
1618
|
+
|
|
1619
|
+
if not self.actors:
|
|
1620
|
+
raise ValueError("No actors provided")
|
|
1621
|
+
|
|
1622
|
+
# Handle both single strategy and strategy hierarchy
|
|
1623
|
+
if isinstance(strategy, str):
|
|
1624
|
+
self.strategies = [strategy]
|
|
1625
|
+
else:
|
|
1626
|
+
self.strategies = list(strategy)
|
|
1627
|
+
|
|
1628
|
+
# Validate strategies
|
|
1629
|
+
valid_strategies = {"prefix-aware", "requests", "kv-cache", "round-robin"}
|
|
1630
|
+
for s in self.strategies:
|
|
1631
|
+
if s not in valid_strategies:
|
|
1632
|
+
raise ValueError(
|
|
1633
|
+
f"Invalid strategy '{s}'. Must be one of {valid_strategies}"
|
|
1634
|
+
)
|
|
1635
|
+
|
|
1636
|
+
if not self.strategies:
|
|
1637
|
+
raise ValueError("At least one strategy must be provided")
|
|
1638
|
+
|
|
1639
|
+
self.strategy = self.strategies[
|
|
1640
|
+
0
|
|
1641
|
+
] # Primary strategy for backward compatibility
|
|
1642
|
+
self.prefix_length = prefix_length
|
|
1643
|
+
self.overload_threshold = overload_threshold
|
|
1644
|
+
self._round_robin_index = 0 # For round-robin fallback
|
|
1645
|
+
|
|
1646
|
+
def select_actor(
    self,
    prompt: str | list[int] | None = None,
    request_context: dict[str, Any] | None = None,
) -> int:
    """Walk the strategy hierarchy and return the first actor index it yields.

    A strategy that raises is logged and skipped. ``"prefix-aware"`` is
    skipped when no prompt is given. If every strategy is skipped or fails,
    a random actor index is returned.

    Args:
        prompt: The input prompt (string or token list) for prefix-aware routing.
        request_context: Additional context for routing decisions. Not read by
            this method.

    Returns:
        int: Index of the selected actor in the actors list.

    Raises:
        ValueError: If no actors are available.
    """
    if not self.actors:
        raise ValueError("No actors available for selection")

    # Strategies that need no prompt, keyed to the method implementing them.
    # The method is looked up inside the ``try`` so failures stay recoverable.
    promptless_selectors = {
        "requests": "_select_by_requests",
        "kv-cache": "_select_by_cache_usage",
        "round-robin": "_select_round_robin",
    }

    # Try each strategy in order until one succeeds
    for position, strategy in enumerate(self.strategies, start=1):
        try:
            torchrl_logger.debug(
                f"Trying strategy {position}/{len(self.strategies)}: {strategy}"
            )

            if strategy == "prefix-aware":
                if prompt is None:
                    torchrl_logger.debug(
                        "No prompt provided for prefix-aware routing, trying next strategy"
                    )
                    continue
                return self._select_by_prefix_aware(prompt)

            selector_name = promptless_selectors.get(strategy)
            if selector_name is None:
                torchrl_logger.warning(
                    f"Unknown strategy: {strategy}, trying next strategy"
                )
                continue
            return getattr(self, selector_name)()

        except Exception as e:
            torchrl_logger.warning(
                f"Strategy '{strategy}' failed with error: {e}. "
                f"Trying next strategy..."
            )
            continue

    # All strategies failed, final fallback to random
    torchrl_logger.warning(
        f"All strategies {self.strategies} failed. Falling back to random selection."
    )
    return random.randint(0, len(self.actors) - 1)
|
|
1710
|
+
|
|
1711
|
+
def _select_by_requests(self) -> int:
    """Pick an actor with the fewest unfinished requests.

    Returns:
        int: Index of the chosen actor; ties are broken at random.
    """
    if self.async_vllm is None:
        # Query actors directly
        futures = [
            actor.get_num_unfinished_requests.remote() for actor in self.actors
        ]
        ray = _get_ray()
        request_counts = ray.get(futures)
    else:
        # Use AsyncVLLM's built-in method to get request counts
        request_counts = self.async_vllm.get_num_unfinished_requests()

    # Collect every actor sharing the minimum, then choose randomly among them
    min_requests = min(request_counts)
    candidates = [
        idx for idx, pending in enumerate(request_counts) if pending == min_requests
    ]
    selected_index = random.choice(candidates)

    torchrl_logger.debug(
        f"LoadBalancer (requests): Selected actor {selected_index} "
        f"with {min_requests} pending requests. "
        f"Request counts: {request_counts}"
    )

    return selected_index
|
|
1740
|
+
|
|
1741
|
+
def _select_by_cache_usage(self) -> int:
    """Pick an actor with the lowest KV cache utilization.

    Returns:
        int: Index of the chosen actor; actors within ``1e-6`` of the minimum
            usage are treated as tied and one is chosen at random.
    """
    if self.async_vllm is None:
        # Query actors directly
        futures = [actor.get_cache_usage.remote() for actor in self.actors]
        ray = _get_ray()
        cache_usages = ray.get(futures)
    else:
        # Use AsyncVLLM's built-in method to get cache usage
        cache_usages = self.async_vllm.get_cache_usage()

    # Collect every actor close to the minimum, then choose randomly among them
    min_usage = min(cache_usages)
    candidates = [
        idx
        for idx, usage in enumerate(cache_usages)
        if abs(usage - min_usage) < 1e-6
    ]
    selected_index = random.choice(candidates)

    torchrl_logger.debug(
        f"LoadBalancer (kv-cache): Selected actor {selected_index} "
        f"with {min_usage:.3f} cache usage. "
        f"Cache usages: {[f'{u:.3f}' for u in cache_usages]}"
    )

    return selected_index
|
|
1768
|
+
|
|
1769
|
+
def _select_by_prefix_aware(self, prompt: str | list[int]) -> int:
    """Route by hashing the prompt prefix onto an actor.

    The prefix tokens are hashed and reduced modulo the number of actors.
    If that actor is reported as overloaded, the requests-based selection is
    used instead.

    Args:
        prompt: Input prompt as string or token list.

    Returns:
        int: Selected actor index.

    Raises:
        ValueError: If prefix cannot be extracted.
    """
    try:
        prefix_tokens = self._extract_prefix_tokens(prompt)
        if not prefix_tokens:
            raise ValueError("Could not extract meaningful prefix tokens")

        # Same prefix -> same hash -> same preferred actor
        prefix_hash = hash(tuple(prefix_tokens))
        preferred_actor = prefix_hash % len(self.actors)

        if not self._is_actor_overloaded(preferred_actor):
            torchrl_logger.debug(
                f"LoadBalancer (prefix-aware): Selected actor {preferred_actor} "
                f"for prefix hash {prefix_hash} (tokens: {prefix_tokens[:4]}...)"
            )
            return preferred_actor

        torchrl_logger.debug(
            f"Preferred actor {preferred_actor} is overloaded "
            f"(threshold: {self.overload_threshold}), falling back to load-based selection"
        )
        # Fall back to requests-based selection
        return self._select_by_requests()

    except Exception as e:
        # Log here, then let the caller decide how to recover
        torchrl_logger.warning(f"Prefix-aware routing failed: {e}")
        raise
|
|
1810
|
+
|
|
1811
|
+
def _select_round_robin(self) -> int:
    """Return the next actor index in cyclic order and advance the cursor.

    Returns:
        int: The current ``_round_robin_index`` reduced modulo the number of
        actors. The stored index is then moved one step forward (also
        modulo the number of actors).
    """
    cursor = self._round_robin_index
    selected = cursor % len(self.actors)
    self._round_robin_index = (cursor + 1) % len(self.actors)

    torchrl_logger.debug(f"LoadBalancer (round-robin): Selected actor {selected}")
    return selected
|
|
1818
|
+
|
|
1819
|
+
def _extract_prefix_tokens(self, prompt: str | list[int]) -> list[int]:
|
|
1820
|
+
"""Extract prefix tokens from prompt (string or token list).
|
|
1821
|
+
|
|
1822
|
+
Args:
|
|
1823
|
+
prompt: Input prompt.
|
|
1824
|
+
|
|
1825
|
+
Returns:
|
|
1826
|
+
list[int]: Prefix tokens (up to self.prefix_length).
|
|
1827
|
+
|
|
1828
|
+
Raises:
|
|
1829
|
+
ValueError: If tokenization fails or prompt is invalid.
|
|
1830
|
+
"""
|
|
1831
|
+
if isinstance(prompt, list):
|
|
1832
|
+
# Already tokenized
|
|
1833
|
+
if not prompt:
|
|
1834
|
+
raise ValueError("Empty token list provided")
|
|
1835
|
+
return prompt[: self.prefix_length]
|
|
1836
|
+
|
|
1837
|
+
elif isinstance(prompt, str):
|
|
1838
|
+
# Need to tokenize - this requires access to tokenizer
|
|
1839
|
+
if not prompt.strip():
|
|
1840
|
+
raise ValueError("Empty or whitespace-only string provided")
|
|
1841
|
+
|
|
1842
|
+
# Try to get tokenizer from AsyncVLLM instance
|
|
1843
|
+
if self.async_vllm is not None:
|
|
1844
|
+
try:
|
|
1845
|
+
# This is a simplistic approach - in practice you'd want to cache the tokenizer
|
|
1846
|
+
# For now, use a simple heuristic based on string content
|
|
1847
|
+
return self._simple_string_hash(prompt)
|
|
1848
|
+
except Exception as e:
|
|
1849
|
+
torchrl_logger.warning(f"Could not tokenize string: {e}")
|
|
1850
|
+
return self._simple_string_hash(prompt)
|
|
1851
|
+
else:
|
|
1852
|
+
# Fall back to simple string hashing
|
|
1853
|
+
return self._simple_string_hash(prompt)
|
|
1854
|
+
else:
|
|
1855
|
+
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
|
|
1856
|
+
|
|
1857
|
+
def _simple_string_hash(self, text: str) -> list[int]:
    """Create pseudo-tokens from a string for prefix routing.

    This is a fallback for when proper tokenization isn't available. The
    text is split on whitespace, the first ``self.prefix_length`` words are
    kept, and each word is mapped to an integer in ``[0, 50000)``.

    The mapping uses CRC32 over the UTF-8 bytes of each word rather than the
    built-in ``hash()``: Python salts ``str`` hashes per interpreter process
    (unless ``PYTHONHASHSEED`` is pinned), which would send the same prefix
    to different actors in different processes or after a restart.

    Args:
        text: Prompt text to convert.

    Returns:
        list[int]: One pseudo-token per retained word.

    Raises:
        ValueError: If no words remain after stripping, splitting and
            truncating the text.
    """
    # Function-scope import: only this fallback needs zlib.
    import zlib

    words = text.strip().split()[: self.prefix_length]
    if not words:
        raise ValueError("No words found in text")

    pseudo_vocab_size = 50000  # Simulate vocab size
    return [
        # "surrogatepass" keeps lone surrogates from raising UnicodeEncodeError.
        zlib.crc32(word.encode("utf-8", "surrogatepass")) % pseudo_vocab_size
        for word in words
    ]
|
|
1872
|
+
|
|
1873
|
+
def _is_actor_overloaded(self, actor_index: int) -> bool:
    """Tell whether one actor carries clearly more load than the average.

    The unfinished-request counts of all actors are fetched (through
    ``self.async_vllm`` when it is set, otherwise through Ray remote calls
    on ``self.actors``). The actor counts as overloaded when its own count
    is strictly greater than the mean count times ``self.overload_threshold``.

    Args:
        actor_index: Index of the actor to check.

    Returns:
        bool: True if the actor is overloaded. False when no counts are
        available or when the load could not be queried (the error is
        logged as a warning).
    """
    try:
        if self.async_vllm is None:
            futures = [
                actor.get_num_unfinished_requests.remote() for actor in self.actors
            ]
            ray = _get_ray()
            request_counts = ray.get(futures)
        else:
            request_counts = self.async_vllm.get_num_unfinished_requests()

        # Nothing to compare against.
        if not request_counts:
            return False

        avg_requests = sum(request_counts) / len(request_counts)
        actor_requests = request_counts[actor_index]
        is_overloaded = actor_requests > avg_requests * self.overload_threshold

        torchrl_logger.debug(
            f"Actor {actor_index}: {actor_requests} requests, "
            f"avg: {avg_requests:.1f}, threshold: {avg_requests * self.overload_threshold:.1f}, "
            f"overloaded: {is_overloaded}"
        )
    except Exception as e:
        # Best effort: a failed load check must not block routing.
        torchrl_logger.warning(f"Could not check actor load: {e}")
        return False  # Assume not overloaded if we can't check
    else:
        return is_overloaded
|
|
1911
|
+
|
|
1912
|
+
def get_stats(self) -> dict[str, Any]:
    """Collect a snapshot of the load balancer configuration and actor load.

    Returns:
        dict: The configured strategies, actor count, prefix length,
        overload threshold and round-robin cursor, plus an ``"actor_stats"``
        list with one entry per actor (index, pending requests, cache
        usage). If querying the actors fails, the error is logged and its
        message is stored under ``"error"``; entries gathered so far are kept.
    """
    stats = {
        "strategies": self.strategies,
        "primary_strategy": self.strategy,  # For backward compatibility
        "num_actors": len(self.actors),
        "prefix_length": self.prefix_length,
        "overload_threshold": self.overload_threshold,
        "round_robin_index": self._round_robin_index,
        "actor_stats": [],
    }
    per_actor = stats["actor_stats"]

    try:
        if self.async_vllm is None:
            # Submit every remote call before blocking on any result.
            request_futures = [
                actor.get_num_unfinished_requests.remote() for actor in self.actors
            ]
            cache_futures = [actor.get_cache_usage.remote() for actor in self.actors]
            ray = _get_ray()
            request_counts = ray.get(request_futures)
            cache_usages = ray.get(cache_futures)
        else:
            request_counts = self.async_vllm.get_num_unfinished_requests()
            cache_usages = self.async_vllm.get_cache_usage()

        paired = zip(request_counts, cache_usages)
        for idx, (pending, usage) in enumerate(paired):
            entry = {
                "actor_index": idx,
                "pending_requests": pending,
                "cache_usage": usage,
            }
            per_actor.append(entry)

    except Exception as e:
        torchrl_logger.warning(f"Error gathering load balancer stats: {e}")
        stats["error"] = str(e)

    return stats
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
def make_async_vllm_engine(
    *,
    model_name: str,
    num_devices: int | None = None,
    num_replicas: int = 1,
    verbose: bool = True,
    compile: bool = True,
    enable_fp32_output: bool = False,
    tensor_parallel_size: int | None = None,
    data_parallel_size: int | None = None,
    pipeline_parallel_size: int | None = None,
    **kwargs,
) -> AsyncVLLM:
    """Create and launch an async vLLM engine service.

    All arguments are keyword-only.

    Keyword Args:
        model_name (str): The model name to pass to vLLM.
        num_devices (int, optional): Number of devices to use, per replica.
            When given together with ``tensor_parallel_size`` the two must
            be equal.
        num_replicas (int): Number of engine replicas to create.
        verbose (bool, optional): Whether to raise the ``vllm`` loggers to
            INFO so throughput statistics are displayed. Defaults to True.
        compile (bool, optional): Whether to enable model compilation for
            better performance. Only used when ``compilation_config`` is not
            supplied in ``kwargs``. Defaults to True.
        enable_fp32_output (bool, optional): Whether to enable FP32 output
            for the final layer by setting the ``VLLM_ENABLE_FP32_OUTPUT``
            environment variable to ``"1"`` in the current process. This can
            help with numerical stability for certain models. Requires
            model-specific support in torchrl.modules.llm.backends._models.
            Defaults to False.
        tensor_parallel_size (int, optional): Tensor parallel degree per
            replica. Defaults to ``num_devices`` when that is given,
            otherwise 1.
        data_parallel_size (int, optional): Number of data parallel groups
            to use. Defaults to 1.
        pipeline_parallel_size (int, optional): Number of pipeline parallel
            groups to use. Defaults to 1.
        **kwargs: Additional arguments passed to AsyncEngineArgs.
            ``distributed_executor_backend`` defaults to ``"ray"`` and
            ``enable_prefix_caching`` defaults to True unless provided.

    Returns:
        AsyncVLLM: The launched engine service.

    Raises:
        ImportError: If vllm is not installed.
        ValueError: If ``tensor_parallel_size`` and ``num_devices`` are both
            given and differ.
        RuntimeError: Not raised directly by this function; presumably
            raised during launch when no CUDA devices are available
            (TODO confirm against ``AsyncVLLM.launch``).

    Example:
        >>> # Create a single-GPU async engine
        >>> service = make_async_vllm_engine(model_name="Qwen/Qwen2.5-3B")
        >>>
        >>> # Create a 2-GPU tensor parallel async engine with 2 replicas
        >>> service = make_async_vllm_engine(
        ...     model_name="Qwen/Qwen2.5-3B", num_devices=2, num_replicas=2
        ... )
        >>> # Generate text
        >>> result = service.generate("Hello, world!", sampling_params)
        >>>
        >>> # Create with FP32 output enabled
        >>> service = make_async_vllm_engine(
        ...     model_name="Qwen/Qwen2.5-3B", enable_fp32_output=True
        ... )
    """
    if not _has_vllm:
        raise ImportError(
            "vllm is not installed. Please install it with `pip install vllm`."
        )

    from vllm import AsyncEngineArgs

    # The FP32 output switch is communicated through the process environment.
    if enable_fp32_output:
        os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
        torchrl_logger.info(
            "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
            "This will use FP32 for the final output layer if the model supports it."
        )

    if verbose:
        import logging

        # vLLM reports throughput statistics at INFO level, so both loggers
        # are raised to INFO.
        for logger_name in ("vllm.engine.metrics", "vllm"):
            logging.getLogger(logger_name).setLevel(logging.INFO)

        torchrl_logger.info(
            "Enabled verbose vLLM logging - throughput statistics will be displayed"
        )

    # tensor_parallel_size follows num_devices unless set explicitly; if both
    # are set they have to agree.
    if tensor_parallel_size is None:
        tensor_parallel_size = 1 if num_devices is None else num_devices
    elif num_devices is not None and tensor_parallel_size != num_devices:
        raise ValueError(f"tensor_parallel_size must be set to {num_devices}")

    data_parallel_size = 1 if data_parallel_size is None else data_parallel_size
    pipeline_parallel_size = (
        1 if pipeline_parallel_size is None else pipeline_parallel_size
    )

    # Defaults below only apply when the caller did not pass the key.
    kwargs.setdefault("distributed_executor_backend", "ray")
    kwargs.setdefault("enable_prefix_caching", True)
    # Level 3 is PIECEWISE compilation, level 0 is NO_COMPILATION.
    kwargs.setdefault("compilation_config", {"level": 3 if compile else 0})

    engine_args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        data_parallel_size=data_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        worker_extension_cls="torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker",
        **kwargs,
    )

    return AsyncVLLM.launch(engine_args, num_replicas)
|