tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/core/disagg_executor.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+from concurrent.futures import Future
+from multiprocessing import Lock
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import worker_receiver_cache_from_config
+from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
+                                      get_open_port)
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import AsyncModelRunnerOutput
+from vllm.v1.serial_utils import run_method
+from vllm.v1.worker.worker_base import WorkerWrapperBase
+
+logger = init_logger(__name__)
+
+
+class DisaggExecutor(Executor):
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        slice_config = getattr(self.vllm_config.device_config, "slice")
+        idx = slice_config[0]
+        jax_devices = slice_config[-1]
+        devices = []
+        if isinstance(idx, int):
+            sizes = slice_config[1]
+            start = sum(sizes[0:idx])
+            end = start + sizes[idx]
+
+            devices = jax_devices[start:end]
+            setattr(self.vllm_config.device_config, "slice",
+                    (idx + 1, sizes, jax_devices))
+            logger.debug(
+                f"Creating DisaggExecutor with {devices}, index: {start} -> {end}"
+            )
+        elif isinstance(idx, tuple):
+            slice_idx = slice_config[1]
+            sizes = slice_config[2][slice_idx]
+            start_row, start_col = idx
+            selected_devices = []
+            max_row, max_col = 0, 0
+            for device in jax_devices:
+                coords = device.coords
+                max_row = max(max_row, coords[0])
+                max_col = max(max_col, coords[1])
+                if coords[0] >= start_row and coords[0] < start_row + sizes[0]:
+                    if coords[1] >= start_col and coords[
+                            1] < start_col + sizes[1]:
+                        selected_devices.append(device)
+            max_row, max_col = max_row + 1, max_col + 1
+
+            devices = selected_devices
+            if start_col + sizes[1] >= max_col:
+                start_row += sizes[0]
+                start_col = 0
+            else:
+                start_col += sizes[1]
+
+            setattr(self.vllm_config.device_config, "slice",
+                    ((start_row, start_col), slice_idx + 1, slice_config[2],
+                     jax_devices))
+            logger.debug(
+                f"Creating DisaggExecutor with {devices}, next start: {((start_row, start_col), slice_idx+1, slice_config[2])}"
+            )
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        local_rank = 0
+        rank = 0
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+            devices=devices,
+        )
+        self.mm_receiver_cache = worker_receiver_cache_from_config(
+            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None,
+                       non_block: bool = False) -> List[Any]:
+        if kwargs is None:
+            kwargs = {}
+
+        if not non_block:
+            return [run_method(self.driver_worker, method, args, kwargs)]
+
+        try:
+            result = run_method(self.driver_worker, method, args, kwargs)
+            if isinstance(result, AsyncModelRunnerOutput):
+                if (async_thread := self.async_output_thread) is not None:
+                    return [async_thread.submit(result.get_output)]
+                result = result.get_output()
+            future = Future[Any]()
+            future.set_result(result)
+        except Exception as e:
+            future = Future[Any]()
+            future.set_exception(e)
+        return [future]
+
+    def check_health(self) -> None:
+        # DisaggExecutor will always be healthy as long as
+        # it's running.
+        return
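The slice bookkeeping above is easier to follow outside the diff. Below is a minimal, standalone sketch of the 1-D (integer-index) branch; select_devices_1d is a hypothetical helper (not part of the package), and sizes/jax_devices stand in for the state the executor reads from and writes back to vllm_config.device_config.slice.

# Hypothetical sketch of the 1-D slice selection performed in
# DisaggExecutor._init_executor (illustration only, not part of the package).
def select_devices_1d(idx, sizes, jax_devices):
    """Return the devices for slice `idx` plus the state handed to the next executor."""
    start = sum(sizes[:idx])      # devices consumed by earlier slices
    end = start + sizes[idx]      # exclusive end of this slice
    devices = jax_devices[start:end]
    next_state = (idx + 1, sizes, jax_devices)  # the next executor picks up from here
    return devices, next_state

# Example: three slices of sizes (4, 2, 8) carved out of 14 devices.
devices, state = select_devices_1d(0, (4, 2, 8), list(range(14)))
assert devices == [0, 1, 2, 3] and state[0] == 1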
tpu_inference/core/disagg_utils.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import Tuple
+
+PREFILL_SLICES = 'PREFILL_SLICES'
+DECODE_SLICES = 'DECODE_SLICES'
+
+
+def is_disagg_enabled() -> bool:
+    # We trigger our code path as long as prefill slices are set. This
+    # allows us to test interleave mode effectively with the code path
+    # for comparison purposes.
+    return PREFILL_SLICES in os.environ
+
+
+def _parse_slices(slices_str: str) -> Tuple[int, ...]:
+    """Parse the slices environment variable and return a list of integers, each the size of a slice.
+
+    For example, if slices_str is set to `2x2,2x1,2x4`, we should return `(4, 2, 8)`.
+
+    Throws an exception if the slice string is malformed.
+    """
+    if not slices_str:
+        return ()
+
+    try:
+        slice_sizes = []
+        for s in slices_str.split(','):
+            dims = s.split('x')
+            if len(dims) == 1:
+                slice_sizes.append(int(dims[0]))
+            elif len(dims) == 2:
+                slice_sizes.append((int(dims[0]), int(dims[1])))
+            else:
+                raise ValueError("Each slice must be in 'N' or 'NxM' format.")
+        return tuple(slice_sizes)
+    except ValueError as e:
+        raise ValueError(f"Malformed slice string: '{slices_str}'") from e
+
+
+def get_prefill_slices() -> Tuple[int, ...]:
+    if PREFILL_SLICES not in os.environ:
+        return ()
+    return _parse_slices(os.environ[PREFILL_SLICES])
+
+
+def get_decode_slices() -> Tuple[int, ...]:
+    if DECODE_SLICES not in os.environ:
+        return ()
+    return _parse_slices(os.environ[DECODE_SLICES])
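A quick usage sketch of the parsing helpers above, setting the environment variables in-process purely for illustration. Note that, as the code is written, NxM entries come back as (rows, cols) tuples rather than the flattened sizes shown in the _parse_slices docstring example.

# Illustration only; mirrors the behavior of the module above as written.
import os

from tpu_inference.core import disagg_utils

os.environ["PREFILL_SLICES"] = "2x2,2x1"   # two 2-D prefill slices
os.environ["DECODE_SLICES"] = "8"          # one 1-D decode slice of 8 chips

print(disagg_utils.is_disagg_enabled())    # True: prefill slices are set
print(disagg_utils.get_prefill_slices())   # ((2, 2), (2, 1))
print(disagg_utils.get_decode_slices())    # (8,)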
tpu_inference/core/sched/dp_scheduler.py
@@ -0,0 +1,523 @@
+import copy
+from collections import defaultdict, deque
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+from vllm.v1.core.sched.interface import SchedulerInterface
+from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
+                                        SchedulerOutput)
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class DPSchedulerOutput(SchedulerOutput):
+    """Extended SchedulerOutput that includes DP rank assignments."""
+    assigned_dp_rank: Optional[Dict[str, int]] = None
+
+    def __init__(self, *args, assigned_dp_rank=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.assigned_dp_rank = assigned_dp_rank or {}
+
+
+class DPScheduler(SchedulerInterface):
+    """
+    DPScheduler is used when DP size is >= 2. Otherwise the default vLLM scheduler is used.
+
+    The DPScheduler manages:
+    1. Multiple vLLM Schedulers (one per DP rank)
+    2. Request-to-scheduler assignment
+
+    Each Scheduler manages its own logical KV cache shard and scheduling logic.
+
+    **Load Balancing**
+
+    For new requests:
+    - If there is a prefix cache hit, assigns the request to the rank with the best hit
+    - Otherwise, assigns the request to the rank with the fewest total tokens
+
+    Once a DP rank is assigned to a request, it remains fixed for the request's lifetime.
+    A request will be freed from its assigned rank when it is completed or preempted.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        block_size: int,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        self.vllm_config = vllm_config
+        self.block_size = block_size
+        self.log_stats = log_stats
+        self.connector = None
+        self.structured_output_manager = structured_output_manager
+
+        # DP state
+        self.dp_size = vllm_config.sharding_config.total_dp_size
+        self.assigned_dp_rank: Dict[str, int] = {}  # req_id -> dp_rank
+        self.cached_schedulers_output = deque()
+        self._create_per_rank_configs(kv_cache_config)
+
+        # The original scheduler class could be Scheduler or AsyncScheduler
+        original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
+        self.schedulers: List[Scheduler] = []
+        for rank in range(self.dp_size):
+            scheduler = original_scheduler_cls(
+                vllm_config=self.vllm_config,
+                kv_cache_config=self.per_rank_kv_cache_configs[rank],
+                structured_output_manager=structured_output_manager,
+                block_size=block_size,
+                mm_registry=mm_registry,
+                include_finished_set=include_finished_set,
+                log_stats=log_stats,
+            )
+            self.schedulers.append(scheduler)
+
+        logger.info(
+            f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
+            f"per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
+            f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
+        )
+
+    def _create_per_rank_configs(self, kv_cache_config: KVCacheConfig) -> None:
+        self.per_rank_kv_cache_configs: List[KVCacheConfig] = []
+        for _ in range(self.dp_size):
+            rank_config = copy.deepcopy(kv_cache_config)
+            rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
+            self.per_rank_kv_cache_configs.append(rank_config)
+
+    def _get_rank_token_counts(self) -> Dict[int, int]:
+        """Calculate total tokens currently assigned to each DP rank."""
+        rank_tokens = {rank: 0 for rank in range(self.dp_size)}
+
+        for rank, scheduler in enumerate(self.schedulers):
+            for request in scheduler.running:
+                rank_tokens[rank] += request.num_tokens
+            for request in scheduler.waiting:
+                rank_tokens[rank] += request.num_tokens
+
+        return rank_tokens
+
+    def _find_best_rank_for_request(self, request: Request) -> int:
+        """Find the best DP rank for a new request based on load balancing."""
+        rank_tokens = self._get_rank_token_counts()
+
+        # First, try to find a rank with a prefix cache hit
+        best_cache_rank = None
+        best_cache_tokens = 0
+        for rank, scheduler in enumerate(self.schedulers):
+            blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
+                request)
+            if cached_tokens > best_cache_tokens:
+                best_cache_tokens = cached_tokens
+                best_cache_rank = rank
+        if best_cache_tokens > 0:
+            return best_cache_rank
+
+        # Otherwise, find the rank with the fewest tokens
+        selected_rank = min(rank_tokens, key=rank_tokens.get)
+        return selected_rank
+
+    def add_request(self, request: Request) -> None:
+        """
+        Add a new request to the appropriate DP rank scheduler.
+
+        This is the main entry point for new requests. The scheduler will:
+        1. Determine the best DP rank for the request (load balancing + cache hits)
+        2. Assign the request to that rank
+        3. Add the request to the rank's scheduler
+        """
+        assert request.request_id not in self.assigned_dp_rank, (
+            f"Request {request.request_id} already "
+            f"assigned to rank {self.assigned_dp_rank[request.request_id]})")
+        rank = self._find_best_rank_for_request(request)
+        self.assigned_dp_rank[request.request_id] = rank
+        self.schedulers[rank].add_request(request)
+
+    def schedule(self) -> DPSchedulerOutput:
+        """
+        Main scheduling method that coordinates all DP rank schedulers.
+
+        Process:
+        1. Add any new requests to appropriate DP ranks
+        2. Run each scheduler independently
+        3. Combine outputs from all schedulers
+        4. Return unified scheduling result
+        """
+        # Run each scheduler independently
+        rank_outputs = []
+        for rank, scheduler in enumerate(self.schedulers):
+            logger.debug(
+                f"Running scheduler for rank {rank}: "
+                f"{len(scheduler.running)} running, {len(scheduler.waiting)} waiting"
+            )
+            output = scheduler.schedule()
+            rank_outputs.append(output)
+
+        # Cache scheduler outputs to use in `update_from_output`
+        self.cached_schedulers_output.append(rank_outputs)
+
+        # Return combined scheduler outputs
+        combined_output = self._combine_scheduler_outputs(rank_outputs)
+
+        logger.debug(
+            f"DPScheduler scheduled: "
+            f"{combined_output.total_num_scheduled_tokens} total tokens, "
+            f"{len(combined_output.scheduled_new_reqs)} new requests, "
+            f"{len(combined_output.scheduled_cached_reqs.req_ids)} cached requests"
+        )
+
+        return combined_output
+
+    def _combine_scheduler_outputs(
+            self, rank_outputs: List[SchedulerOutput]) -> DPSchedulerOutput:
+        """Combine outputs from all DP rank schedulers into a unified output."""
+
+        # Combine new requests
+        all_new_reqs = []
+        for output in rank_outputs:
+            all_new_reqs.extend(output.scheduled_new_reqs)
+
+        # Combine cached request data
+        combined_cached_data = self._combine_cached_request_data(rank_outputs)
+
+        # Combine token counts and other metrics
+        combined_num_scheduled_tokens = {}
+        combined_spec_decode_tokens = {}
+        combined_encoder_inputs = {}
+        total_scheduled_tokens = 0
+
+        for output in rank_outputs:
+            combined_num_scheduled_tokens.update(output.num_scheduled_tokens)
+            combined_spec_decode_tokens.update(
+                output.scheduled_spec_decode_tokens)
+            combined_encoder_inputs.update(output.scheduled_encoder_inputs)
+            total_scheduled_tokens += output.total_num_scheduled_tokens
+
+        # Combine finished request IDs
+        combined_finished_req_ids = set()
+        for output in rank_outputs:
+            combined_finished_req_ids.update(output.finished_req_ids)
+
+        # Combine other fields (take from first non-empty or use defaults)
+        num_common_prefix_blocks = rank_outputs[
+            0].num_common_prefix_blocks if rank_outputs else []
+
+        # Create DP rank assignment mapping for scheduled requests
+        assigned_dp_rank = {}
+        for req_id in combined_num_scheduled_tokens.keys():
+            assigned_dp_rank[req_id] = self.assigned_dp_rank[req_id]
+
+        return DPSchedulerOutput(
+            scheduled_new_reqs=all_new_reqs,
+            scheduled_cached_reqs=combined_cached_data,
+            num_scheduled_tokens=combined_num_scheduled_tokens,
+            total_num_scheduled_tokens=total_scheduled_tokens,
+            scheduled_spec_decode_tokens=combined_spec_decode_tokens,
+            scheduled_encoder_inputs=combined_encoder_inputs,
+            num_common_prefix_blocks=num_common_prefix_blocks,
+            finished_req_ids=combined_finished_req_ids,
+            free_encoder_mm_hashes=set(),
+            assigned_dp_rank=assigned_dp_rank,
+        )
+
+    def _combine_cached_request_data(
+            self, rank_outputs: List[SchedulerOutput]) -> CachedRequestData:
+        """Combine cached request data from all DP rank schedulers."""
+        combined_req_ids = []
+        combined_resumed_req_ids = []
+        combined_new_token_ids = []
+        combined_all_token_ids = []
+        combined_new_block_ids = []
+        combined_num_computed_tokens = []
+        combined_num_output_tokens = []
+
+        for output in rank_outputs:
+            cached_data = output.scheduled_cached_reqs
+
+            combined_req_ids.extend(cached_data.req_ids)
+            combined_resumed_req_ids.extend(cached_data.resumed_req_ids)
+            combined_new_token_ids.extend(cached_data.new_token_ids)
+            combined_all_token_ids.extend(cached_data.all_token_ids)
+            combined_new_block_ids.extend(cached_data.new_block_ids)
+            combined_num_computed_tokens.extend(
+                cached_data.num_computed_tokens)
+            combined_num_output_tokens.extend(cached_data.num_output_tokens)
+
+        return CachedRequestData(
+            req_ids=combined_req_ids,
+            resumed_req_ids=combined_resumed_req_ids,
+            new_token_ids=combined_new_token_ids,
+            all_token_ids=combined_all_token_ids,
+            new_block_ids=combined_new_block_ids,
+            num_computed_tokens=combined_num_computed_tokens,
+            num_output_tokens=combined_num_output_tokens,
+        )
+
+    def get_grammar_bitmask(
+        self,
+        scheduler_output: DPSchedulerOutput,
+    ) -> GrammarOutput | None:
+        """
+        Generate grammar bitmask for structured output requests across all DP ranks.
+
+        This method calls get_grammar_bitmask on each underlying scheduler and
+        combines their outputs, similar to how other operations are handled.
+        """
+        # Use the most recent cached outputs from the schedule() call
+        if not self.cached_schedulers_output:
+            return None
+
+        rank_scheduler_outputs = self.cached_schedulers_output[
+            -1]  # Get the most recent
+
+        combined_structured_output_request_ids = []
+        combined_bitmasks = []
+
+        # Get grammar bitmask from each DP rank scheduler
+        for rank, scheduler in enumerate(self.schedulers):
+            rank_output = rank_scheduler_outputs[rank]
+            grammar_output = scheduler.get_grammar_bitmask(rank_output)
+
+            if grammar_output is not None:
+                combined_structured_output_request_ids.extend(
+                    grammar_output.structured_output_request_ids)
+                combined_bitmasks.append(grammar_output.grammar_bitmask)
+
+        if not combined_structured_output_request_ids:
+            return None
+
+        # Combine bitmasks - concatenate along the batch dimension
+        if len(combined_bitmasks) == 1:
+            combined_bitmask = combined_bitmasks[0]
+        else:
+            combined_bitmask = torch.cat(combined_bitmasks, dim=0)
+
+        return GrammarOutput(combined_structured_output_request_ids,
+                             combined_bitmask)
+
+    def update_from_output(
+        self, scheduler_output: DPSchedulerOutput,
+        model_runner_output: ModelRunnerOutput
+    ) -> dict[int, EngineCoreOutputs]:
+        """
+        Update all DP rank schedulers based on model runner output.
+
+        We need to route the model runner output to the appropriate scheduler
+        based on which rank each request belongs to.
+        """
+        # Group model runner outputs by DP rank
+        rank_model_outputs = self._split_model_output_by_rank(
+            model_runner_output)
+        rank_scheduler_outputs = self.cached_schedulers_output.popleft()
+        # Update each scheduler with its portion of the output
+        combined_engine_outputs = defaultdict(list)
+        for rank, scheduler in enumerate(self.schedulers):
+            rank_engine_outputs = scheduler.update_from_output(
+                rank_scheduler_outputs[rank], rank_model_outputs[rank])
+            for client_idx, engine_output in rank_engine_outputs.items():
+                combined_engine_outputs[client_idx].append(engine_output)
+
+        # Clean up finished requests from DP tracking
+        self._cleanup_finished_requests(scheduler_output.finished_req_ids)
+
+        # Return combined EngineCoreOutput
+        for client_idx, engine_outputs in combined_engine_outputs.items():
+            combined_output = EngineCoreOutputs()
+            outputs = []
+            finished_requests = set()
+            for engine_output in engine_outputs:
+                outputs.extend(engine_output.outputs)
+                if engine_output.finished_requests:
+                    finished_requests.update(engine_output.finished_requests)
+            combined_output.engine_index = engine_outputs[0].engine_index
+            combined_output.outputs = outputs
+            combined_output.finished_requests = finished_requests
+            combined_output.scheduler_stats = self.make_stats()
+            combined_engine_outputs[client_idx] = combined_output
+
+        return combined_engine_outputs
+
+    def _split_model_output_by_rank(
+            self,
+            global_model_output: ModelRunnerOutput) -> List[ModelRunnerOutput]:
+        """Split the model runner output by DP rank for individual scheduler updates."""
+        outputs = [
+            ModelRunnerOutput(
+                req_ids=[],
+                req_id_to_index=global_model_output.req_id_to_index,
+                sampled_token_ids=global_model_output.sampled_token_ids,
+                logprobs=global_model_output.logprobs,
+                prompt_logprobs_dict=global_model_output.prompt_logprobs_dict,
+                pooler_output=None,
+                num_nans_in_logits=global_model_output.num_nans_in_logits,
+                kv_connector_output=global_model_output.kv_connector_output,
+            ) for _ in range(self.dp_size)
+        ]
+
+        for req_id in global_model_output.req_ids:
+            rank = self.assigned_dp_rank[req_id]
+            outputs[rank].req_ids.append(req_id)
+
+        return outputs
+
+    def _cleanup_finished_requests(self, finished_req_ids: set[str]) -> None:
+        """Remove finished requests from our DP rank assignment tracking."""
+        for req_id in finished_req_ids:
+            if req_id in self.assigned_dp_rank:
+                del self.assigned_dp_rank[req_id]
+
+    def finish_requests(self, request_ids, finished_status) -> None:
+        """Forward request finish signals to the appropriate DP rank schedulers."""
+        if isinstance(request_ids, str):
+            request_ids = [request_ids]
+
+        # Route finish signals to appropriate schedulers
+        rank_request_ids = defaultdict(list)
+        for req_id in request_ids:
+            rank = self.assigned_dp_rank[req_id]
+            rank_request_ids[rank].append(req_id)
+
+        # Forward to each scheduler
+        for rank, req_ids in rank_request_ids.items():
+            self.schedulers[rank].finish_requests(req_ids, finished_status)
+
+    def get_num_unfinished_requests(self) -> int:
+        """Get total number of unfinished requests across all DP ranks."""
+        return sum(scheduler.get_num_unfinished_requests()
+                   for scheduler in self.schedulers)
+
+    def has_finished_requests(self) -> bool:
+        """Check if any DP rank has finished requests."""
+        return any(scheduler.has_finished_requests()
+                   for scheduler in self.schedulers)
+
+    def get_request_counts(self) -> Tuple[int, int]:
+        """Get total (running, waiting) request counts across all DP ranks."""
+        total_running = sum(
+            len(scheduler.running) for scheduler in self.schedulers)
+        total_waiting = sum(
+            len(scheduler.waiting) for scheduler in self.schedulers)
+        return total_running, total_waiting
+
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all DP rank schedulers."""
+        return all(scheduler.reset_prefix_cache()
+                   for scheduler in self.schedulers)
+
+    def make_stats(self,
+                   spec_decoding_stats=None,
+                   kv_connector_stats=None) -> Optional[SchedulerStats]:
+        """Combine stats from all DP rank schedulers."""
+        if not self.log_stats:
+            return None
+
+        # Aggregate stats from all schedulers
+        total_running_reqs = 0
+        total_waiting_reqs = 0
+        total_kv_cache_usage = 0.0
+
+        combined_prefix_cache_stats = PrefixCacheStats()
+        combined_connector_prefix_cache_stats: Optional[
+            PrefixCacheStats] = None
+
+        for scheduler in self.schedulers:
+            rank_stats = scheduler.make_stats(spec_decoding_stats,
+                                              kv_connector_stats)
+            if rank_stats is None:
+                continue
+
+            total_running_reqs += rank_stats.num_running_reqs
+            total_waiting_reqs += rank_stats.num_waiting_reqs
+            total_kv_cache_usage += rank_stats.kv_cache_usage
+
+            # Combine prefix cache stats
+            if rank_stats.prefix_cache_stats:
+                combined_prefix_cache_stats.reset = rank_stats.prefix_cache_stats.reset
+                combined_prefix_cache_stats.requests += rank_stats.prefix_cache_stats.requests
+                combined_prefix_cache_stats.queries += rank_stats.prefix_cache_stats.queries
+                combined_prefix_cache_stats.hits += rank_stats.prefix_cache_stats.hits
+
+            # Combine connector prefix cache stats
+            if rank_stats.connector_prefix_cache_stats:
+                if combined_connector_prefix_cache_stats is None:
+                    combined_connector_prefix_cache_stats = PrefixCacheStats()
+                combined_connector_prefix_cache_stats.reset = rank_stats.connector_prefix_cache_stats.reset
+                combined_connector_prefix_cache_stats.requests += rank_stats.connector_prefix_cache_stats.requests
+                combined_connector_prefix_cache_stats.queries += rank_stats.connector_prefix_cache_stats.queries
+                combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits
+
+        # Average KV cache usage across ranks
+        avg_kv_cache_usage = total_kv_cache_usage / len(
+            self.schedulers) if self.schedulers else 0.0
+
+        return SchedulerStats(
+            num_running_reqs=total_running_reqs,
+            num_waiting_reqs=total_waiting_reqs,
+            kv_cache_usage=avg_kv_cache_usage,
+            prefix_cache_stats=combined_prefix_cache_stats,
+            connector_prefix_cache_stats=combined_connector_prefix_cache_stats,
+            spec_decoding_stats=spec_decoding_stats,
+            kv_connector_stats=kv_connector_stats.data
+            if kv_connector_stats else None,
+        )
+
+    def update_draft_token_ids(self, draft_token_ids) -> None:
+        """Forward draft token updates to the appropriate DP rank schedulers."""
+        # Group draft tokens by DP rank based on request assignments
+        rank_draft_tokens = defaultdict(lambda: {
+            "req_ids": [],
+            "draft_token_ids": []
+        })
+
+        for req_id, tokens in zip(draft_token_ids.req_ids,
+                                  draft_token_ids.draft_token_ids):
+            if req_id in self.assigned_dp_rank:
+                rank = self.assigned_dp_rank[req_id]
+                rank_draft_tokens[rank]["req_ids"].append(req_id)
+                rank_draft_tokens[rank]["draft_token_ids"].append(tokens)
+
+        # Forward to each scheduler
+        for rank, draft_data in rank_draft_tokens.items():
+            # Create a draft_token_ids object for this rank (mock structure)
+            rank_draft_token_ids = type(draft_token_ids)(
+                req_ids=draft_data["req_ids"],
+                draft_token_ids=draft_data["draft_token_ids"])
+            self.schedulers[rank].update_draft_token_ids(rank_draft_token_ids)
+
+    def shutdown(self) -> None:
+        """Shutdown all DP rank schedulers."""
+        for scheduler in self.schedulers:
+            scheduler.shutdown()
+
+
+def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
+    """
+    Update vLLM configuration to use DPScheduler when DP size > 1.
+    """
+    dp_size = vllm_config.sharding_config.total_dp_size
+
+    if dp_size > 1:
+        if vllm_config.scheduler_config.async_scheduling:
+            vllm_config.scheduler_config._original_scheduler_cls = AsyncScheduler
+        else:
+            vllm_config.scheduler_config._original_scheduler_cls = Scheduler
+
+        vllm_config.scheduler_config.scheduler_cls = DPScheduler
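To make the two-step load-balancing policy in the DPScheduler docstring concrete, here is a small self-contained sketch, decoupled from the vLLM scheduler classes. The input dictionaries stand in for what the real code derives from kv_cache_manager.get_computed_blocks() and from each per-rank scheduler's running/waiting queues; pick_dp_rank is a hypothetical helper, not part of the package.

# Standalone illustration of the rank-selection heuristic used by
# DPScheduler._find_best_rank_for_request (illustration only).
from typing import Dict


def pick_dp_rank(cached_tokens_per_rank: Dict[int, int],
                 total_tokens_per_rank: Dict[int, int]) -> int:
    # 1) Prefer the rank with the largest prefix-cache hit, if any hit exists.
    best_rank = max(cached_tokens_per_rank, key=cached_tokens_per_rank.get)
    if cached_tokens_per_rank[best_rank] > 0:
        return best_rank
    # 2) Otherwise fall back to the rank with the fewest total tokens.
    return min(total_tokens_per_rank, key=total_tokens_per_rank.get)


# Rank 1 holds a 128-token prefix hit, so it wins despite being more loaded.
assert pick_dp_rank({0: 0, 1: 128}, {0: 1000, 1: 4000}) == 1
# No cache hit anywhere: the least-loaded rank (0) is chosen.
assert pick_dp_rank({0: 0, 1: 0}, {0: 1000, 1: 4000}) == 0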