tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
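tpu_inference/core/sched/dp_scheduler.py (new file, +814 lines)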
@@ -0,0 +1,814 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import multiprocessing.reduction
from collections import defaultdict, deque
from dataclasses import dataclass
from enum import Enum
from multiprocessing import Process, Queue
from typing import Any, Dict, List, Optional, Tuple

import cloudpickle
import torch
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
                                        SchedulerOutput)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

from tpu_inference.logger import init_logger
from tpu_inference.utils import time_function

logger = init_logger(__name__)


class SchedulerCommand(Enum):
    """Enum for scheduler worker process commands."""
    ADD_REQUEST = "add_request"
    SCHEDULE = "schedule"
    FINISH_REQUESTS = "finish_requests"
    UPDATE_DRAFT_TOKEN_IDS = "update_draft_token_ids"
    UPDATE_FROM_OUTPUT = "update_from_output"
    GET_GRAMMAR_BITMASK = "get_grammar_bitmask"
    MAKE_STATS = "make_stats"
    RESET_PREFIX_CACHE = "reset_prefix_cache"
    GET_NUM_UNFINISHED_REQUESTS = "get_num_unfinished_requests"
    HAS_FINISHED_REQUESTS = "has_finished_requests"
    GET_REQUEST_COUNTS = "get_request_counts"
    GET_TOKEN_COUNT = "get_token_count"
    GET_COMPUTED_BLOCKS = "get_computed_blocks"
    SHUTDOWN = "shutdown"


class SchedulerWorkerError(Exception):
    """Exception raised when a scheduler worker process encounters an error."""

    def __init__(self, rank: int, message: str):
        self.rank = rank
        self.message = message
        super().__init__(f"Scheduler worker {rank} error: {message}")


# Monkey-patch multiprocessing to use cloudpickle
# Standard pickle fails to serialize the vLLM Request object.
_original_dumps = multiprocessing.reduction.ForkingPickler.dumps
_original_loads = multiprocessing.reduction.ForkingPickler.loads


def _cloudpickle_dumps(obj, protocol=None):
    """Use cloudpickle for serialization."""
    return cloudpickle.dumps(obj, protocol=protocol)


def _cloudpickle_loads(data):
    """Use cloudpickle for deserialization."""
    return cloudpickle.loads(data)


def _enable_cloudpickle():
    """Enable cloudpickle for multiprocessing queues."""
    multiprocessing.reduction.ForkingPickler.dumps = staticmethod(
        _cloudpickle_dumps)
    multiprocessing.reduction.ForkingPickler.loads = staticmethod(
        _cloudpickle_loads)


def _disable_cloudpickle():
    """Restore original pickle for multiprocessing."""
    multiprocessing.reduction.ForkingPickler.dumps = _original_dumps
    multiprocessing.reduction.ForkingPickler.loads = _original_loads


def _scheduler_worker_process(
    rank: int,
    input_queue: Queue,
    output_queue: Queue,
    vllm_config: Any,
    kv_cache_config: Any,
    structured_output_manager: Any,
    block_size: int,
    mm_registry: Any,
    include_finished_set: bool,
    log_stats: bool,
    original_scheduler_cls: type,
):
    """Worker process that manages a single scheduler instance."""
    # Initialize the scheduler in this process
    scheduler = original_scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        structured_output_manager=structured_output_manager,
        block_size=block_size,
        mm_registry=mm_registry,
        include_finished_set=include_finished_set,
        log_stats=log_stats,
    )

    logger.debug(f"Scheduler worker process {rank} started")

    # Process commands from the input queue
    while True:
        try:
            command, data = input_queue.get()

            match command:
                case SchedulerCommand.ADD_REQUEST:
                    request = data
                    scheduler.add_request(request)
                    output_queue.put(None)  # Signal completion

                case SchedulerCommand.SCHEDULE:
                    output = scheduler.schedule()
                    output_queue.put(output)

                case SchedulerCommand.FINISH_REQUESTS:
                    request_ids, finished_status = data
                    scheduler.finish_requests(request_ids, finished_status)
                    output_queue.put(None)  # Signal completion

                case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
                    draft_token_ids = data
                    scheduler.update_draft_token_ids(draft_token_ids)
                    output_queue.put(None)  # Signal completion

                case SchedulerCommand.UPDATE_FROM_OUTPUT:
                    scheduler_output, model_runner_output = data
                    result = scheduler.update_from_output(
                        scheduler_output, model_runner_output)
                    output_queue.put(result)

                case SchedulerCommand.GET_GRAMMAR_BITMASK:
                    scheduler_output = data
                    result = scheduler.get_grammar_bitmask(scheduler_output)
                    output_queue.put(result)

                case SchedulerCommand.MAKE_STATS:
                    spec_decoding_stats, kv_connector_stats = data
                    result = scheduler.make_stats(spec_decoding_stats,
                                                  kv_connector_stats)
                    output_queue.put(result)

                case SchedulerCommand.RESET_PREFIX_CACHE:
                    result = scheduler.reset_prefix_cache()
                    output_queue.put(result)

                case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
                    result = scheduler.get_num_unfinished_requests()
                    output_queue.put(result)

                case SchedulerCommand.HAS_FINISHED_REQUESTS:
                    result = scheduler.has_finished_requests()
                    output_queue.put(result)

                case SchedulerCommand.GET_REQUEST_COUNTS:
                    running = len(scheduler.running)
                    waiting = len(scheduler.waiting)
                    output_queue.put((running, waiting))

                case SchedulerCommand.GET_TOKEN_COUNT:
                    # Calculate total tokens across running and waiting requests
                    total_tokens = 0
                    for req in scheduler.running:
                        total_tokens += len(req.all_token_ids)
                    for req in scheduler.waiting:
                        total_tokens += len(req.all_token_ids)
                    output_queue.put(total_tokens)

                case SchedulerCommand.GET_COMPUTED_BLOCKS:
                    request = data
                    blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
                        request)
                    output_queue.put((blocks, cached_tokens))

                case SchedulerCommand.SHUTDOWN:
                    scheduler.shutdown()
                    output_queue.put(None)  # Signal completion
                    break
                case _:
                    error = SchedulerWorkerError(
                        rank, f"Unknown command: {command}")
                    output_queue.put(error)
                    raise error

        except Exception as e:
            logger.error(f"Error in scheduler worker {rank}: {e}",
                         exc_info=True)
            # Put error on output queue
            error = SchedulerWorkerError(rank, str(e))
            output_queue.put(error)


@dataclass
class DPSchedulerOutput(SchedulerOutput):
    """Extended SchedulerOutput that includes DP rank assignments."""
    assigned_dp_rank: Optional[Dict[str, int]] = None

    def __init__(self, *args, assigned_dp_rank=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.assigned_dp_rank = assigned_dp_rank or {}


class DPScheduler(SchedulerInterface):
    """
    DPScheduler is used when DP size is >=2. Otherwise the default vLLM scheduler is used.

    The DPScheduler manages:
    1. Multiple vLLM Schedulers (one per DP rank)
    2. Request-to-scheduler assignment

    Each Scheduler manages its own logical KV cache shard and scheduling logic.

    **Load Balancing**

    For new requests:
    - If there is prefix cache hit, assigns request to the rank with the best hit
    - Otherwise, assigns request to the rank with the least total tokens

    Once a DP rank is assigned to a request, it remains fixed for the request's lifetime.
    A request will be freed from its assigned rank when it is completed or preempted.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: KVCacheConfig,
        structured_output_manager: StructuredOutputManager,
        block_size: int,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        self.vllm_config = vllm_config
        self.block_size = block_size
        self.log_stats = log_stats
        self.connector = None
        self.structured_output_manager = structured_output_manager

        # DP state
        self.dp_size = vllm_config.sharding_config.total_dp_size
        self.assigned_dp_rank: Dict[str, int] = {}  # req_id -> dp_rank
        self.cached_schedulers_output = deque()
        self._create_per_rank_configs(kv_cache_config)

        # The original scheduler class could be Scheduler or AsyncScheduler
        original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls

        # Enable cloudpickle for multiprocessing to handle local functions
        _enable_cloudpickle()

        # Create worker processes with one input and one output queue each
        import multiprocessing
        ctx = multiprocessing.get_context('fork')
        self.input_queues: List[Queue] = []
        self.output_queues: List[Queue] = []
        self.processes: List[Process] = []

        for rank in range(self.dp_size):
            input_queue = ctx.Queue()
            output_queue = ctx.Queue()

            self.input_queues.append(input_queue)
            self.output_queues.append(output_queue)

            process = ctx.Process(
                target=_scheduler_worker_process,
                args=(
                    rank,
                    input_queue,
                    output_queue,
                    self.vllm_config,
                    self.per_rank_kv_cache_configs[rank],
                    structured_output_manager,
                    block_size,
                    mm_registry,
                    include_finished_set,
                    log_stats,
                    original_scheduler_cls,
                ),
            )
            process.start()
            self.processes.append(process)

        logger.info(
            f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
            f"started {self.dp_size} worker processes with cloudpickle. "
            f"Per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
            f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
        )

    def _create_per_rank_configs(self, kv_cache_config: KVCacheConfig) -> None:
        self.per_rank_kv_cache_configs: List[KVCacheConfig] = []
        for _ in range(self.dp_size):
            rank_config = copy.deepcopy(kv_cache_config)
            rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
            self.per_rank_kv_cache_configs.append(rank_config)

    def _get_result_from_queue(self, queue: Queue) -> Any:
        result = queue.get()
        if isinstance(result, SchedulerWorkerError):
            raise result
        return result

    def _get_rank_token_counts(self) -> Dict[int, int]:
        """Calculate total tokens currently assigned to each DP rank."""
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.GET_TOKEN_COUNT, None))

        rank_tokens = {}
        for rank in range(self.dp_size):
            token_count = self._get_result_from_queue(self.output_queues[rank])
            rank_tokens[rank] = token_count

        return rank_tokens

    def _find_best_rank_for_request(self, request: Request) -> int:
        """Find the best DP rank for a new request based on load balancing."""
        rank_tokens = self._get_rank_token_counts()

        # First, try to find a rank with prefix cache hit
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.GET_COMPUTED_BLOCKS, request))

        best_cache_rank = None
        best_cache_tokens = 0
        for rank in range(self.dp_size):
            blocks, cached_tokens = self._get_result_from_queue(
                self.output_queues[rank])
            if cached_tokens > best_cache_tokens:
                best_cache_tokens = cached_tokens
                best_cache_rank = rank
        if best_cache_tokens > 0:
            return best_cache_rank

        # Otherwise, find rank with least tokens
        selected_rank = min(rank_tokens, key=rank_tokens.get)
        return selected_rank

    def add_request(self, request: Request) -> None:
        """
        Add a new request to the appropriate DP rank scheduler.

        This is the main entry point for new requests. The scheduler will:
        1. Determine the best DP rank for the request (load balancing + cache hits)
        2. Assign the request to that rank
        3. Add the request to the rank's scheduler
        """
        assert request.request_id not in self.assigned_dp_rank, (
            f"Request {request.request_id} already "
            f"assigned to rank {self.assigned_dp_rank[request.request_id]})")
        rank = self._find_best_rank_for_request(request)
        self.assigned_dp_rank[request.request_id] = rank

        self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
        self._get_result_from_queue(self.output_queues[rank])

    @time_function
    def schedule(self) -> DPSchedulerOutput:
        """
        Main scheduling method that coordinates all DP rank schedulers.

        Process:
        1. Add any new requests to appropriate DP ranks
        2. Run each scheduler independently in parallel
        3. Combine outputs from all schedulers
        4. Return unified scheduling result
        """
        # Run each scheduler independently
        for rank in range(self.dp_size):
            self.input_queues[rank].put((SchedulerCommand.SCHEDULE, None))

        # Collect outputs from all workers (blocking)
        rank_outputs = []
        for rank in range(self.dp_size):
            output = self._get_result_from_queue(self.output_queues[rank])
            rank_outputs.append(output)

        # Cache scheduler outputs to use in `update_from_output`
        self.cached_schedulers_output.append(rank_outputs)

        # Return combined scheduler outputs
        combined_output = self._combine_scheduler_outputs(rank_outputs)

        logger.debug(
            f"DPScheduler scheduled: "
            f"{combined_output.total_num_scheduled_tokens} total tokens, "
            f"{len(combined_output.scheduled_new_reqs)} new requests, "
            f"{len(combined_output.scheduled_cached_reqs.req_ids)} cached requests"
        )

        return combined_output

    def _combine_scheduler_outputs(
            self, rank_outputs: List[SchedulerOutput]) -> DPSchedulerOutput:
        """Combine outputs from all DP rank schedulers into a unified output."""

        # Combine new requests
        all_new_reqs = []
        for output in rank_outputs:
            all_new_reqs.extend(output.scheduled_new_reqs)

        # Combine cached request data
        combined_cached_data = self._combine_cached_request_data(rank_outputs)

        # Combine token counts and other metrics
        combined_num_scheduled_tokens = {}
        combined_spec_decode_tokens = {}
        combined_encoder_inputs = {}
        total_scheduled_tokens = 0

        for output in rank_outputs:
            combined_num_scheduled_tokens.update(output.num_scheduled_tokens)
            combined_spec_decode_tokens.update(
                output.scheduled_spec_decode_tokens)
            combined_encoder_inputs.update(output.scheduled_encoder_inputs)
            total_scheduled_tokens += output.total_num_scheduled_tokens

        # Combine finished request IDs
        combined_finished_req_ids = set()
        for output in rank_outputs:
            combined_finished_req_ids.update(output.finished_req_ids)

        # Combine other fields (take from first non-empty or use defaults)
        num_common_prefix_blocks = rank_outputs[
            0].num_common_prefix_blocks if rank_outputs else []

        # Create DP rank assignment mapping for scheduled requests
        assigned_dp_rank = {}
        for req_id in combined_num_scheduled_tokens.keys():
            assigned_dp_rank[req_id] = self.assigned_dp_rank[req_id]

        return DPSchedulerOutput(
            scheduled_new_reqs=all_new_reqs,
            scheduled_cached_reqs=combined_cached_data,
            num_scheduled_tokens=combined_num_scheduled_tokens,
            total_num_scheduled_tokens=total_scheduled_tokens,
            scheduled_spec_decode_tokens=combined_spec_decode_tokens,
            scheduled_encoder_inputs=combined_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            finished_req_ids=combined_finished_req_ids,
            free_encoder_mm_hashes=set(),
            assigned_dp_rank=assigned_dp_rank,
        )

    def _combine_cached_request_data(
            self, rank_outputs: List[SchedulerOutput]) -> CachedRequestData:
        """Combine cached request data from all DP rank schedulers."""
        combined_req_ids = []
        combined_resumed_req_ids = []
        combined_new_token_ids = []
        combined_all_token_ids = []
        combined_new_block_ids = []
        combined_num_computed_tokens = []
        combined_num_output_tokens = []

        for output in rank_outputs:
            cached_data = output.scheduled_cached_reqs

            combined_req_ids.extend(cached_data.req_ids)
            combined_resumed_req_ids.extend(cached_data.resumed_req_ids)
            combined_new_token_ids.extend(cached_data.new_token_ids)
            combined_all_token_ids.extend(cached_data.all_token_ids)
            combined_new_block_ids.extend(cached_data.new_block_ids)
            combined_num_computed_tokens.extend(
                cached_data.num_computed_tokens)
            combined_num_output_tokens.extend(cached_data.num_output_tokens)

        return CachedRequestData(
            req_ids=combined_req_ids,
            resumed_req_ids=combined_resumed_req_ids,
            new_token_ids=combined_new_token_ids,
            all_token_ids=combined_all_token_ids,
            new_block_ids=combined_new_block_ids,
            num_computed_tokens=combined_num_computed_tokens,
            num_output_tokens=combined_num_output_tokens,
        )

    def get_grammar_bitmask(
        self,
        scheduler_output: DPSchedulerOutput,
    ) -> GrammarOutput | None:
        """
        Generate grammar bitmask for structured output requests across all DP ranks.

        This method calls get_grammar_bitmask on each underlying scheduler and
        combines their outputs, similar to how other operations are handled.
        """
        # Use the most recent cached outputs from the schedule() call
        if not self.cached_schedulers_output:
            return None

        rank_scheduler_outputs = self.cached_schedulers_output[
            -1]  # Get the most recent

        combined_structured_output_request_ids = []
        combined_bitmasks = []

        # Get grammar bitmask from each DP rank scheduler
        for rank in range(self.dp_size):
            self.input_queues[rank].put((SchedulerCommand.GET_GRAMMAR_BITMASK,
                                         rank_scheduler_outputs[rank]))
        for rank in range(self.dp_size):
            grammar_output = self._get_result_from_queue(
                self.output_queues[rank])
            if grammar_output is not None:
                combined_structured_output_request_ids.extend(
                    grammar_output.structured_output_request_ids)
                combined_bitmasks.append(grammar_output.grammar_bitmask)

        if not combined_structured_output_request_ids:
            return None

        # Combine bitmasks - concatenate along the batch dimension
        if len(combined_bitmasks) == 1:
            combined_bitmask = combined_bitmasks[0]
        else:
            combined_bitmask = torch.cat(combined_bitmasks, dim=0)

        return GrammarOutput(combined_structured_output_request_ids,
                             combined_bitmask)

    def update_from_output(
        self, scheduler_output: DPSchedulerOutput,
        model_runner_output: ModelRunnerOutput
    ) -> dict[int, EngineCoreOutputs]:
        """
        Update all DP rank schedulers based on model runner output.

        We need to route the model runner output to the appropriate scheduler
        based on which rank each request belongs to.
        """
        # Group model runner outputs by DP rank
        rank_model_outputs = self._split_model_output_by_rank(
            model_runner_output)
        rank_scheduler_outputs = self.cached_schedulers_output.popleft()
        # Update each scheduler with its portion of the output
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.UPDATE_FROM_OUTPUT,
                 (rank_scheduler_outputs[rank], rank_model_outputs[rank])))

        combined_engine_outputs = defaultdict(list)
        for rank in range(self.dp_size):
            rank_engine_outputs = self._get_result_from_queue(
                self.output_queues[rank])
            for client_idx, engine_output in rank_engine_outputs.items():
                combined_engine_outputs[client_idx].append(engine_output)

        # Clean up finished requests from DP tracking
        self._cleanup_finished_requests(scheduler_output.finished_req_ids)

        # Return combined EngineCoreOutput
        for client_idx, engine_outputs in combined_engine_outputs.items():
            combined_output = EngineCoreOutputs()
            outputs = []
            finished_requests = set()
            for engine_output in engine_outputs:
                outputs.extend(engine_output.outputs)
                if engine_output.finished_requests:
                    finished_requests.update(engine_output.finished_requests)
            combined_output.engine_index = engine_outputs[0].engine_index
            combined_output.outputs = outputs
            combined_output.finished_requests = finished_requests
            combined_output.scheduler_stats = self.make_stats()
            combined_engine_outputs[client_idx] = combined_output

        return combined_engine_outputs

    def _split_model_output_by_rank(
            self,
            global_model_output: ModelRunnerOutput) -> List[ModelRunnerOutput]:
        """Split the model runner output by DP rank for individual scheduler updates."""
        outputs = [
            ModelRunnerOutput(
                req_ids=[],
                req_id_to_index=global_model_output.req_id_to_index,
                sampled_token_ids=global_model_output.sampled_token_ids,
                logprobs=global_model_output.logprobs,
                prompt_logprobs_dict=global_model_output.prompt_logprobs_dict,
                pooler_output=None,
                num_nans_in_logits=global_model_output.num_nans_in_logits,
                kv_connector_output=global_model_output.kv_connector_output,
            ) for _ in range(self.dp_size)
        ]

        for req_id in global_model_output.req_ids:
            rank = self.assigned_dp_rank[req_id]
            outputs[rank].req_ids.append(req_id)

        return outputs

    def _cleanup_finished_requests(self, finished_req_ids: set[str]) -> None:
        """Remove finished requests from our DP rank assignment tracking."""
        for req_id in finished_req_ids:
            if req_id in self.assigned_dp_rank:
                del self.assigned_dp_rank[req_id]

    def finish_requests(self, request_ids, finished_status) -> None:
        """Forward request finish signals to the appropriate DP rank schedulers."""
        if isinstance(request_ids, str):
            request_ids = [request_ids]

        # Route finish signals to appropriate schedulers
        rank_request_ids = defaultdict(list)
        for req_id in request_ids:
            rank = self.assigned_dp_rank[req_id]
            rank_request_ids[rank].append(req_id)

        # Forward to each scheduler
        for rank, req_ids in rank_request_ids.items():
            self.input_queues[rank].put(
                (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
            self._get_result_from_queue(self.output_queues[rank])

    def get_num_unfinished_requests(self) -> int:
        """Get total number of unfinished requests across all DP ranks."""
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))

        total = 0
        for rank in range(self.dp_size):
            count = self._get_result_from_queue(self.output_queues[rank])
            total += count
        return total

    def has_finished_requests(self) -> bool:
        """Check if any DP rank has finished requests."""
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.HAS_FINISHED_REQUESTS, None))

        has_finished_any = False
        for rank in range(self.dp_size):
            has_finished_any |= self._get_result_from_queue(
                self.output_queues[rank])
        return has_finished_any

    def get_request_counts(self) -> Tuple[int, int]:
        """Get total (running, waiting) request counts across all DP ranks."""
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.GET_REQUEST_COUNTS, None))

        total_running = 0
        total_waiting = 0
        for rank in range(self.dp_size):
            running, waiting = self._get_result_from_queue(
                self.output_queues[rank])
            total_running += running
            total_waiting += waiting
        return total_running, total_waiting

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all DP rank schedulers."""
        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.RESET_PREFIX_CACHE, None))

        all_success = True
        for rank in range(self.dp_size):
            success = self._get_result_from_queue(self.output_queues[rank])
            all_success &= success
        return all_success

    def make_stats(self,
                   spec_decoding_stats=None,
                   kv_connector_stats=None) -> Optional[SchedulerStats]:
        """Combine stats from all DP rank schedulers."""
        if not self.log_stats:
            return None

        # Aggregate stats from all schedulers
        total_running_reqs = 0
        total_waiting_reqs = 0
        total_kv_cache_usage = 0.0

        combined_prefix_cache_stats = PrefixCacheStats()
        combined_connector_prefix_cache_stats: Optional[
            PrefixCacheStats] = None

        for rank in range(self.dp_size):
            self.input_queues[rank].put(
                (SchedulerCommand.MAKE_STATS, (spec_decoding_stats,
                                               kv_connector_stats)))

        for rank in range(self.dp_size):
            rank_stats = self._get_result_from_queue(self.output_queues[rank])
            if rank_stats is None:
                continue

            total_running_reqs += rank_stats.num_running_reqs
            total_waiting_reqs += rank_stats.num_waiting_reqs
            total_kv_cache_usage += rank_stats.kv_cache_usage

            # Combine prefix cache stats
            if rank_stats.prefix_cache_stats:
                combined_prefix_cache_stats.reset = rank_stats.prefix_cache_stats.reset
                combined_prefix_cache_stats.requests += rank_stats.prefix_cache_stats.requests
                combined_prefix_cache_stats.queries += rank_stats.prefix_cache_stats.queries
                combined_prefix_cache_stats.hits += rank_stats.prefix_cache_stats.hits

            # Combine connector prefix cache stats
            if rank_stats.connector_prefix_cache_stats:
                if combined_connector_prefix_cache_stats is None:
                    combined_connector_prefix_cache_stats = PrefixCacheStats()
                combined_connector_prefix_cache_stats.reset = rank_stats.connector_prefix_cache_stats.reset
                combined_connector_prefix_cache_stats.requests += rank_stats.connector_prefix_cache_stats.requests
                combined_connector_prefix_cache_stats.queries += rank_stats.connector_prefix_cache_stats.queries
                combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits

        # Average KV cache usage across ranks
        avg_kv_cache_usage = total_kv_cache_usage / self.dp_size if self.dp_size else 0.0

        return SchedulerStats(
            num_running_reqs=total_running_reqs,
            num_waiting_reqs=total_waiting_reqs,
            kv_cache_usage=avg_kv_cache_usage,
            prefix_cache_stats=combined_prefix_cache_stats,
            connector_prefix_cache_stats=combined_connector_prefix_cache_stats,
            spec_decoding_stats=spec_decoding_stats,
            kv_connector_stats=kv_connector_stats.data
            if kv_connector_stats else None,
        )

    def update_draft_token_ids(self, draft_token_ids) -> None:
        """Forward draft token updates to the appropriate DP rank schedulers."""
        # Group draft tokens by DP rank based on request assignments
        rank_draft_tokens = defaultdict(lambda: {
            "req_ids": [],
            "draft_token_ids": []
        })

        for req_id, tokens in zip(draft_token_ids.req_ids,
                                  draft_token_ids.draft_token_ids):
            if req_id in self.assigned_dp_rank:
                rank = self.assigned_dp_rank[req_id]
                rank_draft_tokens[rank]["req_ids"].append(req_id)
                rank_draft_tokens[rank]["draft_token_ids"].append(tokens)

        for rank, draft_data in rank_draft_tokens.items():
            # Create a draft_token_ids object for this rank (mock structure)
            rank_draft_token_ids = type(draft_token_ids)(
                req_ids=draft_data["req_ids"],
                draft_token_ids=draft_data["draft_token_ids"])
            self.input_queues[rank].put(
                (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
                 rank_draft_token_ids))
            self._get_result_from_queue(self.output_queues[rank])

    def shutdown(self) -> None:
        """Shutdown all DP rank scheduler worker processes."""
        # Send shutdown command to all workers
        for rank in range(self.dp_size):
            self.input_queues[rank].put((SchedulerCommand.SHUTDOWN, None))

        # Wait for acknowledgment (blocking)
        for rank in range(self.dp_size):
            self._get_result_from_queue(self.output_queues[rank])

        # Terminate and join all processes
        for process in self.processes:
            process.join(timeout=5.0)
            if process.is_alive():
                process.terminate()
                process.join()

        # Restore original pickle
        _disable_cloudpickle()


def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
    """
    Update vLLM configuration to use DPScheduler when DP size > 1.
    """
    dp_size = vllm_config.sharding_config.total_dp_size

    if dp_size > 1:
        if vllm_config.scheduler_config.async_scheduling:
            vllm_config.scheduler_config._original_scheduler_cls = AsyncScheduler
        else:
            vllm_config.scheduler_config._original_scheduler_cls = Scheduler

        vllm_config.scheduler_config.scheduler_cls = DPScheduler