tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/core/disagg_executor.py

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
from concurrent.futures import Future
from multiprocessing import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
                                      get_open_port)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class DisaggExecutor(Executor):

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        slice_config = getattr(self.vllm_config.device_config, "slice")
        idx = slice_config[0]
        jax_devices = slice_config[-1]
        devices = []
        if isinstance(idx, int):
            sizes = slice_config[1]
            start = sum(sizes[0:idx])
            end = start + sizes[idx]

            devices = jax_devices[start:end]
            setattr(self.vllm_config.device_config, "slice",
                    (idx + 1, sizes, jax_devices))
            logger.debug(
                f"Creating DisaggExecutor with {devices}, index: {start} -> {end}"
            )
        elif isinstance(idx, tuple):
            slice_idx = slice_config[1]
            sizes = slice_config[2][slice_idx]
            start_row, start_col = idx
            selected_devices = []
            max_row, max_col = 0, 0
            for device in jax_devices:
                coords = device.coords
                max_row = max(max_row, coords[0])
                max_col = max(max_col, coords[1])
                if coords[0] >= start_row and coords[0] < start_row + sizes[0]:
                    if coords[1] >= start_col and coords[
                            1] < start_col + sizes[1]:
                        selected_devices.append(device)
            max_row, max_col = max_row + 1, max_col + 1

            devices = selected_devices
            if start_col + sizes[1] >= max_col:
                start_row += sizes[0]
                start_col = 0
            else:
                start_col += sizes[1]

            setattr(self.vllm_config.device_config, "slice",
                    ((start_row, start_col), slice_idx + 1, slice_config[2],
                     jax_devices))
            logger.debug(
                f"Creating DisaggExecutor with {devices}, next start: {((start_row, start_col), slice_idx+1, slice_config[2])}"
            )

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
        rank = 0
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
            devices=devices,
        )
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None,
                       non_block: bool = False) -> List[Any]:
        if kwargs is None:
            kwargs = {}

        if not non_block:
            return [run_method(self.driver_worker, method, args, kwargs)]

        try:
            result = run_method(self.driver_worker, method, args, kwargs)
            if isinstance(result, AsyncModelRunnerOutput):
                if (async_thread := self.async_output_thread) is not None:
                    return [async_thread.submit(result.get_output)]
                result = result.get_output()
            future = Future[Any]()
            future.set_result(result)
        except Exception as e:
            future = Future[Any]()
            future.set_exception(e)
        return [future]

    def check_health(self) -> None:
        # DisaggExecutor will always be healthy as long as
        # it's running.
        return
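In the 1-D (flat slice) branch of `_init_executor`, the `slice` attribute on the device config carries a running index, the tuple of slice sizes, and the full JAX device list; each executor takes the next contiguous block of chips and writes the incremented index back so the following executor claims the next block. The snippet below is a minimal, self-contained sketch of that bookkeeping, with made-up device names rather than real TPU handles; it is an illustration, not part of the package.

from typing import List, Sequence, Tuple


def take_slice(idx: int, sizes: Sequence[int],
               devices: List[str]) -> Tuple[List[str], int]:
    start = sum(sizes[:idx])            # chips consumed by earlier slices
    end = start + sizes[idx]            # this executor's share
    return devices[start:end], idx + 1  # next executor resumes at idx + 1


chips = [f"tpu{i}" for i in range(8)]
first, nxt = take_slice(0, (4, 2, 2), chips)   # tpu0..tpu3
second, _ = take_slice(nxt, (4, 2, 2), chips)  # tpu4..tpu5
print(first, second)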
tpu_inference/core/disagg_utils.py

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Tuple

PREFILL_SLICES = 'PREFILL_SLICES'
DECODE_SLICES = 'DECODE_SLICES'


def is_disagg_enabled() -> bool:
    # We trigger our code path as long as prefill slices are set. This
    # allows us to test interleave mode with the same code path for
    # comparison purposes.
    return PREFILL_SLICES in os.environ


def _parse_slices(slices_str: str) -> Tuple[int, ...]:
    """Parse a slices environment variable and return a tuple of slice sizes.

    Each comma-separated entry is either a single integer 'N' or a 2D shape
    'NxM'. For example, '4,2,8' returns (4, 2, 8) and '2x2,2x1,2x4' returns
    ((2, 2), (2, 1), (2, 4)).

    Raises an exception if the slice string is malformed.
    """
    if not slices_str:
        return ()

    try:
        slice_sizes = []
        for s in slices_str.split(','):
            dims = s.split('x')
            if len(dims) == 1:
                slice_sizes.append(int(dims[0]))
            elif len(dims) == 2:
                slice_sizes.append((int(dims[0]), int(dims[1])))
            else:
                raise ValueError("Each slice must be in 'N' or 'NxM' format.")
        return tuple(slice_sizes)
    except ValueError as e:
        raise ValueError(f"Malformed slice string: '{slices_str}'") from e


def get_prefill_slices() -> Tuple[int, ...]:
    if PREFILL_SLICES not in os.environ:
        return ()
    return _parse_slices(os.environ[PREFILL_SLICES])


def get_decode_slices() -> Tuple[int, ...]:
    if DECODE_SLICES not in os.environ:
        return ()
    return _parse_slices(os.environ[DECODE_SLICES])
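A usage sketch based on the helpers in this file: setting PREFILL_SLICES is what enables the disaggregated code path, and each variable parses into a tuple of slice sizes. The environment values below are made-up examples, not a recommended topology.

import os

# Hypothetical configuration: two 2x2 prefill slices and one 4-chip decode slice.
os.environ["PREFILL_SLICES"] = "2x2,2x2"
os.environ["DECODE_SLICES"] = "4"

from tpu_inference.core import disagg_utils

print(disagg_utils.is_disagg_enabled())   # True
print(disagg_utils.get_prefill_slices())  # ((2, 2), (2, 2))
print(disagg_utils.get_decode_slices())   # (4,)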
File without changes
tpu_inference/core/sched/dp_scheduler.py

@@ -0,0 +1,523 @@
import copy
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
                                        SchedulerOutput)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

from tpu_inference.logger import init_logger

logger = init_logger(__name__)


@dataclass
class DPSchedulerOutput(SchedulerOutput):
    """Extended SchedulerOutput that includes DP rank assignments."""
    assigned_dp_rank: Optional[Dict[str, int]] = None

    def __init__(self, *args, assigned_dp_rank=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.assigned_dp_rank = assigned_dp_rank or {}


class DPScheduler(SchedulerInterface):
    """
    DPScheduler is used when the DP size is >= 2. Otherwise the default vLLM scheduler is used.

    The DPScheduler manages:
    1. Multiple vLLM Schedulers (one per DP rank)
    2. Request-to-scheduler assignment

    Each Scheduler manages its own logical KV cache shard and scheduling logic.

    **Load Balancing**

    For new requests:
    - If there is a prefix cache hit, assign the request to the rank with the best hit
    - Otherwise, assign the request to the rank with the fewest total tokens

    Once a DP rank is assigned to a request, it remains fixed for the request's lifetime.
    A request will be freed from its assigned rank when it is completed or preempted.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: KVCacheConfig,
        structured_output_manager: StructuredOutputManager,
        block_size: int,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        self.vllm_config = vllm_config
        self.block_size = block_size
        self.log_stats = log_stats
        self.connector = None
        self.structured_output_manager = structured_output_manager

        # DP state
        self.dp_size = vllm_config.sharding_config.total_dp_size
        self.assigned_dp_rank: Dict[str, int] = {}  # req_id -> dp_rank
        self.cached_schedulers_output = deque()
        self._create_per_rank_configs(kv_cache_config)

        # The original scheduler class could be Scheduler or AsyncScheduler
        original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
        self.schedulers: List[Scheduler] = []
        for rank in range(self.dp_size):
            scheduler = original_scheduler_cls(
                vllm_config=self.vllm_config,
                kv_cache_config=self.per_rank_kv_cache_configs[rank],
                structured_output_manager=structured_output_manager,
                block_size=block_size,
                mm_registry=mm_registry,
                include_finished_set=include_finished_set,
                log_stats=log_stats,
            )
            self.schedulers.append(scheduler)

        logger.info(
            f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
            f"per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
            f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
        )

    def _create_per_rank_configs(self, kv_cache_config: KVCacheConfig) -> None:
        self.per_rank_kv_cache_configs: List[KVCacheConfig] = []
        for _ in range(self.dp_size):
            rank_config = copy.deepcopy(kv_cache_config)
            rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
            self.per_rank_kv_cache_configs.append(rank_config)

    def _get_rank_token_counts(self) -> Dict[int, int]:
        """Calculate total tokens currently assigned to each DP rank."""
        rank_tokens = {rank: 0 for rank in range(self.dp_size)}

        for rank, scheduler in enumerate(self.schedulers):
            for request in scheduler.running:
                rank_tokens[rank] += request.num_tokens
            for request in scheduler.waiting:
                rank_tokens[rank] += request.num_tokens

        return rank_tokens

    def _find_best_rank_for_request(self, request: Request) -> int:
        """Find the best DP rank for a new request based on load balancing."""
        rank_tokens = self._get_rank_token_counts()

        # First, try to find a rank with a prefix cache hit
        best_cache_rank = None
        best_cache_tokens = 0
        for rank, scheduler in enumerate(self.schedulers):
            blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
                request)
            if cached_tokens > best_cache_tokens:
                best_cache_tokens = cached_tokens
                best_cache_rank = rank
        if best_cache_tokens > 0:
            return best_cache_rank

        # Otherwise, pick the rank with the fewest tokens
        selected_rank = min(rank_tokens, key=rank_tokens.get)
        return selected_rank

    def add_request(self, request: Request) -> None:
        """
        Add a new request to the appropriate DP rank scheduler.

        This is the main entry point for new requests. The scheduler will:
        1. Determine the best DP rank for the request (load balancing + cache hits)
        2. Assign the request to that rank
        3. Add the request to the rank's scheduler
        """
        assert request.request_id not in self.assigned_dp_rank, (
            f"Request {request.request_id} already "
            f"assigned to rank {self.assigned_dp_rank[request.request_id]})")
        rank = self._find_best_rank_for_request(request)
        self.assigned_dp_rank[request.request_id] = rank
        self.schedulers[rank].add_request(request)

    def schedule(self) -> DPSchedulerOutput:
        """
        Main scheduling method that coordinates all DP rank schedulers.

        Process:
        1. Add any new requests to appropriate DP ranks
        2. Run each scheduler independently
        3. Combine outputs from all schedulers
        4. Return unified scheduling result
        """
        # Run each scheduler independently
        rank_outputs = []
        for rank, scheduler in enumerate(self.schedulers):
            logger.debug(
                f"Running scheduler for rank {rank}: "
                f"{len(scheduler.running)} running, {len(scheduler.waiting)} waiting"
            )
            output = scheduler.schedule()
            rank_outputs.append(output)

        # Cache scheduler outputs to use in `update_from_output`
        self.cached_schedulers_output.append(rank_outputs)

        # Return combined scheduler outputs
        combined_output = self._combine_scheduler_outputs(rank_outputs)

        logger.debug(
            f"DPScheduler scheduled: "
            f"{combined_output.total_num_scheduled_tokens} total tokens, "
            f"{len(combined_output.scheduled_new_reqs)} new requests, "
            f"{len(combined_output.scheduled_cached_reqs.req_ids)} cached requests"
        )

        return combined_output

    def _combine_scheduler_outputs(
            self, rank_outputs: List[SchedulerOutput]) -> DPSchedulerOutput:
        """Combine outputs from all DP rank schedulers into a unified output."""

        # Combine new requests
        all_new_reqs = []
        for output in rank_outputs:
            all_new_reqs.extend(output.scheduled_new_reqs)

        # Combine cached request data
        combined_cached_data = self._combine_cached_request_data(rank_outputs)

        # Combine token counts and other metrics
        combined_num_scheduled_tokens = {}
        combined_spec_decode_tokens = {}
        combined_encoder_inputs = {}
        total_scheduled_tokens = 0

        for output in rank_outputs:
            combined_num_scheduled_tokens.update(output.num_scheduled_tokens)
            combined_spec_decode_tokens.update(
                output.scheduled_spec_decode_tokens)
            combined_encoder_inputs.update(output.scheduled_encoder_inputs)
            total_scheduled_tokens += output.total_num_scheduled_tokens

        # Combine finished request IDs
        combined_finished_req_ids = set()
        for output in rank_outputs:
            combined_finished_req_ids.update(output.finished_req_ids)

        # Combine other fields (take from first non-empty or use defaults)
        num_common_prefix_blocks = rank_outputs[
            0].num_common_prefix_blocks if rank_outputs else []

        # Create DP rank assignment mapping for scheduled requests
        assigned_dp_rank = {}
        for req_id in combined_num_scheduled_tokens.keys():
            assigned_dp_rank[req_id] = self.assigned_dp_rank[req_id]

        return DPSchedulerOutput(
            scheduled_new_reqs=all_new_reqs,
            scheduled_cached_reqs=combined_cached_data,
            num_scheduled_tokens=combined_num_scheduled_tokens,
            total_num_scheduled_tokens=total_scheduled_tokens,
            scheduled_spec_decode_tokens=combined_spec_decode_tokens,
            scheduled_encoder_inputs=combined_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            finished_req_ids=combined_finished_req_ids,
            free_encoder_mm_hashes=set(),
            assigned_dp_rank=assigned_dp_rank,
        )

    def _combine_cached_request_data(
            self, rank_outputs: List[SchedulerOutput]) -> CachedRequestData:
        """Combine cached request data from all DP rank schedulers."""
        combined_req_ids = []
        combined_resumed_req_ids = []
        combined_new_token_ids = []
        combined_all_token_ids = []
        combined_new_block_ids = []
        combined_num_computed_tokens = []
        combined_num_output_tokens = []

        for output in rank_outputs:
            cached_data = output.scheduled_cached_reqs

            combined_req_ids.extend(cached_data.req_ids)
            combined_resumed_req_ids.extend(cached_data.resumed_req_ids)
            combined_new_token_ids.extend(cached_data.new_token_ids)
            combined_all_token_ids.extend(cached_data.all_token_ids)
            combined_new_block_ids.extend(cached_data.new_block_ids)
            combined_num_computed_tokens.extend(
                cached_data.num_computed_tokens)
            combined_num_output_tokens.extend(cached_data.num_output_tokens)

        return CachedRequestData(
            req_ids=combined_req_ids,
            resumed_req_ids=combined_resumed_req_ids,
            new_token_ids=combined_new_token_ids,
            all_token_ids=combined_all_token_ids,
            new_block_ids=combined_new_block_ids,
            num_computed_tokens=combined_num_computed_tokens,
            num_output_tokens=combined_num_output_tokens,
        )

    def get_grammar_bitmask(
        self,
        scheduler_output: DPSchedulerOutput,
    ) -> GrammarOutput | None:
        """
        Generate grammar bitmask for structured output requests across all DP ranks.

        This method calls get_grammar_bitmask on each underlying scheduler and
        combines their outputs, similar to how other operations are handled.
        """
        # Use the most recent cached outputs from the schedule() call
        if not self.cached_schedulers_output:
            return None

        rank_scheduler_outputs = self.cached_schedulers_output[
            -1]  # Get the most recent

        combined_structured_output_request_ids = []
        combined_bitmasks = []

        # Get grammar bitmask from each DP rank scheduler
        for rank, scheduler in enumerate(self.schedulers):
            rank_output = rank_scheduler_outputs[rank]
            grammar_output = scheduler.get_grammar_bitmask(rank_output)

            if grammar_output is not None:
                combined_structured_output_request_ids.extend(
                    grammar_output.structured_output_request_ids)
                combined_bitmasks.append(grammar_output.grammar_bitmask)

        if not combined_structured_output_request_ids:
            return None

        # Combine bitmasks - concatenate along the batch dimension
        if len(combined_bitmasks) == 1:
            combined_bitmask = combined_bitmasks[0]
        else:
            combined_bitmask = torch.cat(combined_bitmasks, dim=0)

        return GrammarOutput(combined_structured_output_request_ids,
                             combined_bitmask)

    def update_from_output(
        self, scheduler_output: DPSchedulerOutput,
        model_runner_output: ModelRunnerOutput
    ) -> dict[int, EngineCoreOutputs]:
        """
        Update all DP rank schedulers based on model runner output.

        We need to route the model runner output to the appropriate scheduler
        based on which rank each request belongs to.
        """
        # Group model runner outputs by DP rank
        rank_model_outputs = self._split_model_output_by_rank(
            model_runner_output)
        rank_scheduler_outputs = self.cached_schedulers_output.popleft()
        # Update each scheduler with its portion of the output
        combined_engine_outputs = defaultdict(list)
        for rank, scheduler in enumerate(self.schedulers):
            rank_engine_outputs = scheduler.update_from_output(
                rank_scheduler_outputs[rank], rank_model_outputs[rank])
            for client_idx, engine_output in rank_engine_outputs.items():
                combined_engine_outputs[client_idx].append(engine_output)

        # Clean up finished requests from DP tracking
        self._cleanup_finished_requests(scheduler_output.finished_req_ids)

        # Return combined EngineCoreOutput
        for client_idx, engine_outputs in combined_engine_outputs.items():
            combined_output = EngineCoreOutputs()
            outputs = []
            finished_requests = set()
            for engine_output in engine_outputs:
                outputs.extend(engine_output.outputs)
                if engine_output.finished_requests:
                    finished_requests.update(engine_output.finished_requests)
            combined_output.engine_index = engine_outputs[0].engine_index
            combined_output.outputs = outputs
            combined_output.finished_requests = finished_requests
            combined_output.scheduler_stats = self.make_stats()
            combined_engine_outputs[client_idx] = combined_output

        return combined_engine_outputs

    def _split_model_output_by_rank(
            self,
            global_model_output: ModelRunnerOutput) -> List[ModelRunnerOutput]:
        """Split the model runner output by DP rank for individual scheduler updates."""
        outputs = [
            ModelRunnerOutput(
                req_ids=[],
                req_id_to_index=global_model_output.req_id_to_index,
                sampled_token_ids=global_model_output.sampled_token_ids,
                logprobs=global_model_output.logprobs,
                prompt_logprobs_dict=global_model_output.prompt_logprobs_dict,
                pooler_output=None,
                num_nans_in_logits=global_model_output.num_nans_in_logits,
                kv_connector_output=global_model_output.kv_connector_output,
            ) for _ in range(self.dp_size)
        ]

        for req_id in global_model_output.req_ids:
            rank = self.assigned_dp_rank[req_id]
            outputs[rank].req_ids.append(req_id)

        return outputs

    def _cleanup_finished_requests(self, finished_req_ids: set[str]) -> None:
        """Remove finished requests from our DP rank assignment tracking."""
        for req_id in finished_req_ids:
            if req_id in self.assigned_dp_rank:
                del self.assigned_dp_rank[req_id]

    def finish_requests(self, request_ids, finished_status) -> None:
        """Forward request finish signals to the appropriate DP rank schedulers."""
        if isinstance(request_ids, str):
            request_ids = [request_ids]

        # Route finish signals to appropriate schedulers
        rank_request_ids = defaultdict(list)
        for req_id in request_ids:
            rank = self.assigned_dp_rank[req_id]
            rank_request_ids[rank].append(req_id)

        # Forward to each scheduler
        for rank, req_ids in rank_request_ids.items():
            self.schedulers[rank].finish_requests(req_ids, finished_status)

    def get_num_unfinished_requests(self) -> int:
        """Get total number of unfinished requests across all DP ranks."""
        return sum(scheduler.get_num_unfinished_requests()
                   for scheduler in self.schedulers)

    def has_finished_requests(self) -> bool:
        """Check if any DP rank has finished requests."""
        return any(scheduler.has_finished_requests()
                   for scheduler in self.schedulers)

    def get_request_counts(self) -> Tuple[int, int]:
        """Get total (running, waiting) request counts across all DP ranks."""
        total_running = sum(
            len(scheduler.running) for scheduler in self.schedulers)
        total_waiting = sum(
            len(scheduler.waiting) for scheduler in self.schedulers)
        return total_running, total_waiting

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all DP rank schedulers."""
        return all(scheduler.reset_prefix_cache()
                   for scheduler in self.schedulers)

    def make_stats(self,
                   spec_decoding_stats=None,
                   kv_connector_stats=None) -> Optional[SchedulerStats]:
        """Combine stats from all DP rank schedulers."""
        if not self.log_stats:
            return None

        # Aggregate stats from all schedulers
        total_running_reqs = 0
        total_waiting_reqs = 0
        total_kv_cache_usage = 0.0

        combined_prefix_cache_stats = PrefixCacheStats()
        combined_connector_prefix_cache_stats: Optional[
            PrefixCacheStats] = None

        for scheduler in self.schedulers:
            rank_stats = scheduler.make_stats(spec_decoding_stats,
                                              kv_connector_stats)
            if rank_stats is None:
                continue

            total_running_reqs += rank_stats.num_running_reqs
            total_waiting_reqs += rank_stats.num_waiting_reqs
            total_kv_cache_usage += rank_stats.kv_cache_usage

            # Combine prefix cache stats
            if rank_stats.prefix_cache_stats:
                combined_prefix_cache_stats.reset = rank_stats.prefix_cache_stats.reset
                combined_prefix_cache_stats.requests += rank_stats.prefix_cache_stats.requests
                combined_prefix_cache_stats.queries += rank_stats.prefix_cache_stats.queries
                combined_prefix_cache_stats.hits += rank_stats.prefix_cache_stats.hits

            # Combine connector prefix cache stats
            if rank_stats.connector_prefix_cache_stats:
                if combined_connector_prefix_cache_stats is None:
                    combined_connector_prefix_cache_stats = PrefixCacheStats()
                combined_connector_prefix_cache_stats.reset = rank_stats.connector_prefix_cache_stats.reset
                combined_connector_prefix_cache_stats.requests += rank_stats.connector_prefix_cache_stats.requests
                combined_connector_prefix_cache_stats.queries += rank_stats.connector_prefix_cache_stats.queries
                combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits

        # Average KV cache usage across ranks
        avg_kv_cache_usage = total_kv_cache_usage / len(
            self.schedulers) if self.schedulers else 0.0

        return SchedulerStats(
            num_running_reqs=total_running_reqs,
            num_waiting_reqs=total_waiting_reqs,
            kv_cache_usage=avg_kv_cache_usage,
            prefix_cache_stats=combined_prefix_cache_stats,
            connector_prefix_cache_stats=combined_connector_prefix_cache_stats,
            spec_decoding_stats=spec_decoding_stats,
            kv_connector_stats=kv_connector_stats.data
            if kv_connector_stats else None,
        )

    def update_draft_token_ids(self, draft_token_ids) -> None:
        """Forward draft token updates to the appropriate DP rank schedulers."""
        # Group draft tokens by DP rank based on request assignments
        rank_draft_tokens = defaultdict(lambda: {
            "req_ids": [],
            "draft_token_ids": []
        })

        for req_id, tokens in zip(draft_token_ids.req_ids,
                                  draft_token_ids.draft_token_ids):
            if req_id in self.assigned_dp_rank:
                rank = self.assigned_dp_rank[req_id]
                rank_draft_tokens[rank]["req_ids"].append(req_id)
                rank_draft_tokens[rank]["draft_token_ids"].append(tokens)

        # Forward to each scheduler
        for rank, draft_data in rank_draft_tokens.items():
            # Create a draft_token_ids object for this rank (mock structure)
            rank_draft_token_ids = type(draft_token_ids)(
                req_ids=draft_data["req_ids"],
                draft_token_ids=draft_data["draft_token_ids"])
            self.schedulers[rank].update_draft_token_ids(rank_draft_token_ids)

    def shutdown(self) -> None:
        """Shutdown all DP rank schedulers."""
        for scheduler in self.schedulers:
            scheduler.shutdown()


def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
    """
    Update vLLM configuration to use DPScheduler when DP size > 1.
    """
    dp_size = vllm_config.sharding_config.total_dp_size

    if dp_size > 1:
        if vllm_config.scheduler_config.async_scheduling:
            vllm_config.scheduler_config._original_scheduler_cls = AsyncScheduler
        else:
            vllm_config.scheduler_config._original_scheduler_cls = Scheduler

        vllm_config.scheduler_config.scheduler_cls = DPScheduler
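The placement policy in `_find_best_rank_for_request` above reduces to two rules: prefer the rank with the largest prefix-cache hit, otherwise fall back to the rank with the fewest queued tokens. The following self-contained sketch illustrates just that rule with plain dicts standing in for per-rank scheduler state; it is not the vLLM or tpu-inference API.

from typing import Dict


def pick_rank(cache_hit_tokens: Dict[int, int],
              rank_tokens: Dict[int, int]) -> int:
    # Rank with the largest prefix-cache hit, if any hit exists.
    best_rank = max(cache_hit_tokens, key=cache_hit_tokens.get)
    if cache_hit_tokens[best_rank] > 0:
        return best_rank
    # Otherwise, the least-loaded rank by total queued tokens.
    return min(rank_tokens, key=rank_tokens.get)


print(pick_rank({0: 0, 1: 0}, {0: 1200, 1: 800}))   # -> 1 (least loaded)
print(pick_rank({0: 64, 1: 0}, {0: 1200, 1: 800}))  # -> 0 (cache hit wins)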
File without changes