tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/core/sched/dp_scheduler.py:

```diff
@@ -1,8 +1,27 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
+import multiprocessing.reduction
 from collections import defaultdict, deque
 from dataclasses import dataclass
+from enum import Enum
+from multiprocessing import Process, Queue
+from time import time
 from typing import Any, Dict, List, Optional, Tuple
 
+import cloudpickle
 import torch
 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -19,10 +38,186 @@ from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
 
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import time_function
 
 logger = init_logger(__name__)
 
 
+class SchedulerCommand(Enum):
+    """Enum for scheduler worker process commands."""
+    ADD_REQUEST = "add_request"
+    SCHEDULE = "schedule"
+    FINISH_REQUESTS = "finish_requests"
+    UPDATE_DRAFT_TOKEN_IDS = "update_draft_token_ids"
+    UPDATE_FROM_OUTPUT = "update_from_output"
+    GET_GRAMMAR_BITMASK = "get_grammar_bitmask"
+    MAKE_STATS = "make_stats"
+    RESET_PREFIX_CACHE = "reset_prefix_cache"
+    GET_NUM_UNFINISHED_REQUESTS = "get_num_unfinished_requests"
+    HAS_FINISHED_REQUESTS = "has_finished_requests"
+    GET_REQUEST_COUNTS = "get_request_counts"
+    GET_TOKEN_COUNT = "get_token_count"
+    GET_COMPUTED_BLOCKS = "get_computed_blocks"
+    SHUTDOWN = "shutdown"
+
+
+class SchedulerWorkerError(Exception):
+    """Exception raised when a scheduler worker process encounters an error."""
+
+    def __init__(self, rank: int, message: str):
+        self.rank = rank
+        self.message = message
+        super().__init__(f"Scheduler worker {rank} error: {message}")
+
+
+# Monkey-patch multiprocessing to use cloudpickle
+# Standard pickle fails to serialize the vLLM Request object.
+_original_dumps = multiprocessing.reduction.ForkingPickler.dumps
+_original_loads = multiprocessing.reduction.ForkingPickler.loads
+
+
+def _cloudpickle_dumps(obj, protocol=None):
+    """Use cloudpickle for serialization."""
+    return cloudpickle.dumps(obj, protocol=protocol)
+
+
+def _cloudpickle_loads(data):
+    """Use cloudpickle for deserialization."""
+    return cloudpickle.loads(data)
+
+
+def _enable_cloudpickle():
+    """Enable cloudpickle for multiprocessing queues."""
+    multiprocessing.reduction.ForkingPickler.dumps = staticmethod(
+        _cloudpickle_dumps)
+    multiprocessing.reduction.ForkingPickler.loads = staticmethod(
+        _cloudpickle_loads)
+
+
+def _disable_cloudpickle():
+    """Restore original pickle for multiprocessing."""
+    multiprocessing.reduction.ForkingPickler.dumps = _original_dumps
+    multiprocessing.reduction.ForkingPickler.loads = _original_loads
+
+
+def _scheduler_worker_process(
+    rank: int,
+    input_queue: Queue,
+    output_queues: Dict[str, Queue],
+    vllm_config: Any,
+    kv_cache_config: Any,
+    structured_output_manager: Any,
+    block_size: int,
+    mm_registry: Any,
+    include_finished_set: bool,
+    log_stats: bool,
+    original_scheduler_cls: type,
+):
+    """Worker process that manages a single scheduler instance."""
+    # Initialize the scheduler in this process
+    scheduler = original_scheduler_cls(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        structured_output_manager=structured_output_manager,
+        block_size=block_size,
+        mm_registry=mm_registry,
+        include_finished_set=include_finished_set,
+        log_stats=log_stats,
+    )
+
+    logger.debug(f"Scheduler worker process {rank} started")
+
+    # Process commands from the input queue
+    while True:
+        try:
+            command, data = input_queue.get()
+
+            match command:
+                case SchedulerCommand.ADD_REQUEST:
+                    request = data
+                    scheduler.add_request(request)
+                    output_queues[command.value].put(None)  # Signal completion
+
+                case SchedulerCommand.SCHEDULE:
+                    output = scheduler.schedule()
+                    output_queues[command.value].put(output)
+
+                case SchedulerCommand.FINISH_REQUESTS:
+                    request_ids, finished_status = data
+                    scheduler.finish_requests(request_ids, finished_status)
+                    output_queues[command.value].put(None)  # Signal completion
+
+                case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
+                    draft_token_ids = data
+                    scheduler.update_draft_token_ids(draft_token_ids)
+                    output_queues[command.value].put(None)  # Signal completion
+
+                case SchedulerCommand.UPDATE_FROM_OUTPUT:
+                    scheduler_output, model_runner_output = data
+                    result = scheduler.update_from_output(
+                        scheduler_output, model_runner_output)
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.GET_GRAMMAR_BITMASK:
+                    scheduler_output = data
+                    result = scheduler.get_grammar_bitmask(scheduler_output)
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.MAKE_STATS:
+                    spec_decoding_stats, kv_connector_stats = data
+                    result = scheduler.make_stats(spec_decoding_stats,
+                                                  kv_connector_stats)
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.RESET_PREFIX_CACHE:
+                    result = scheduler.reset_prefix_cache()
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
+                    result = scheduler.get_num_unfinished_requests()
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.HAS_FINISHED_REQUESTS:
+                    result = scheduler.has_finished_requests()
+                    output_queues[command.value].put(result)
+
+                case SchedulerCommand.GET_REQUEST_COUNTS:
+                    running = len(scheduler.running)
+                    waiting = len(scheduler.waiting)
+                    output_queues[command.value].put((running, waiting))
+
+                case SchedulerCommand.GET_TOKEN_COUNT:
+                    # Calculate total tokens across running and waiting requests
+                    total_tokens = 0
+                    for req in scheduler.running:
+                        total_tokens += len(req.all_token_ids)
+                    for req in scheduler.waiting:
+                        total_tokens += len(req.all_token_ids)
+                    output_queues[command.value].put(total_tokens)
+
+                case SchedulerCommand.GET_COMPUTED_BLOCKS:
+                    request = data
+                    blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
+                        request)
+                    output_queues[command.value].put((blocks, cached_tokens))
+
+                case SchedulerCommand.SHUTDOWN:
+                    scheduler.shutdown()
+                    output_queues[command.value].put(None)  # Signal completion
+                    break
+
+                case _:
+                    error = SchedulerWorkerError(
+                        rank, f"Unknown command: {command}")
+                    output_queues[command.value].put(error)
+                    raise error
+
+        except Exception as e:
+            logger.error(f"Error in scheduler worker {rank}: {e}",
+                         exc_info=True)
+            error = SchedulerWorkerError(rank, str(e))
+            output_queues[command.value].put(error)
+
+
 @dataclass
 class DPSchedulerOutput(SchedulerOutput):
     """Extended SchedulerOutput that includes DP rank assignments."""
```
```diff
@@ -77,22 +272,50 @@ class DPScheduler(SchedulerInterface):
 
         # The original scheduler class could be Scheduler or AsyncScheduler
         original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
-
+
+        # Enable cloudpickle for multiprocessing to handle local functions
+        _enable_cloudpickle()
+
+        # Create worker processes with separate output queues for each command type
+        import multiprocessing
+        ctx = multiprocessing.get_context('fork')
+        self.input_queues: List[Queue] = []
+        self.output_queues: Dict[Tuple[int, str], Queue] = {}
+        self.processes: List[Process] = []
+
         for rank in range(self.dp_size):
-
-
-
-
-
-
-
-
+            input_queue = ctx.Queue()
+            self.input_queues.append(input_queue)
+
+            output_queues_for_rank: Dict[str, Queue] = {}
+            for cmd in SchedulerCommand:
+                output_queues_for_rank[cmd.value] = ctx.Queue()
+                self.output_queues[(
+                    rank, cmd.value)] = output_queues_for_rank[cmd.value]
+
+            process = ctx.Process(
+                target=_scheduler_worker_process,
+                args=(
+                    rank,
+                    input_queue,
+                    output_queues_for_rank,
+                    self.vllm_config,
+                    self.per_rank_kv_cache_configs[rank],
+                    structured_output_manager,
+                    block_size,
+                    mm_registry,
+                    include_finished_set,
+                    log_stats,
+                    original_scheduler_cls,
+                ),
             )
-
+            process.start()
+            self.processes.append(process)
 
         logger.info(
             f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
-            f"
+            f"started {self.dp_size} worker processes with cloudpickle. "
+            f"Per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
             f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
         )
 
```
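The constructor above sets up a fan-out/fan-in request-response protocol: one input queue per rank plus one output queue per (rank, command) pair, so a reply to one command can never be consumed by a caller waiting on a different one. Here is a minimal sketch of that round-trip using the same 'fork' start method; the names (`worker`, `ping`, `DP_SIZE`) are invented for illustration and none of this is code from the diff. Note that the 'fork' context is Linux-specific.

```python
from multiprocessing import get_context

ctx = get_context("fork")
DP_SIZE = 2
COMMANDS = ("ping", "stats")


def worker(rank, input_queue, out_queues):
    # Serve commands until told to shut down; reply on the queue
    # dedicated to this command so responses never interleave.
    while True:
        command, data = input_queue.get()
        if command == "shutdown":
            break
        out_queues[command].put((rank, command, data))


input_queues = [ctx.Queue() for _ in range(DP_SIZE)]
output_queues = {(rank, cmd): ctx.Queue()
                 for rank in range(DP_SIZE) for cmd in COMMANDS}

procs = []
for rank in range(DP_SIZE):
    outs = {cmd: output_queues[(rank, cmd)] for cmd in COMMANDS}
    p = ctx.Process(target=worker, args=(rank, input_queues[rank], outs))
    p.start()
    procs.append(p)

# Fan out first so every rank works in parallel, then fan in.
for rank in range(DP_SIZE):
    input_queues[rank].put(("ping", None))
for rank in range(DP_SIZE):
    print(output_queues[(rank, "ping")].get())

for rank in range(DP_SIZE):
    input_queues[rank].put(("shutdown", None))
for p in procs:
    p.join()
```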
```diff
@@ -103,15 +326,39 @@ class DPScheduler(SchedulerInterface):
             rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
             self.per_rank_kv_cache_configs.append(rank_config)
 
+    def _get_result_from_queue(self, rank: int,
+                               command: SchedulerCommand) -> Any:
+        """Get result from the output queue for a specific rank and command type."""
+        queue_obj = self.output_queues[(rank, command.value)]
+        try:
+            start_time = time()
+            result = queue_obj.get()
+            end_time = time()
+            if end_time - start_time > 1.0:
+                logger.warning(
+                    f"Long wait time ({end_time - start_time:.2f}s) for rank {rank} "
+                    f"command {command.value} response.")
+        except EOFError as e:
+            raise RuntimeError(
+                f"Queue error for rank {rank} command {command.value}: "
+                "Worker process terminated unexpectedly. "
+                "This may indicate a crash in the scheduler worker process."
+            ) from e
+        if isinstance(result, SchedulerWorkerError):
+            raise result
+        return result
+
     def _get_rank_token_counts(self) -> Dict[int, int]:
         """Calculate total tokens currently assigned to each DP rank."""
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_TOKEN_COUNT, None))
 
-
-
-
-
-
+        rank_tokens = {}
+        for rank in range(self.dp_size):
+            token_count = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_TOKEN_COUNT)
+            rank_tokens[rank] = token_count
 
         return rank_tokens
 
@@ -120,11 +367,15 @@ class DPScheduler(SchedulerInterface):
         rank_tokens = self._get_rank_token_counts()
 
         # First, try to find a rank with prefix cache hit
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_COMPUTED_BLOCKS, request))
+
         best_cache_rank = None
         best_cache_tokens = 0
-        for rank
-        blocks, cached_tokens =
-
+        for rank in range(self.dp_size):
+            blocks, cached_tokens = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_COMPUTED_BLOCKS)
             if cached_tokens > best_cache_tokens:
                 best_cache_tokens = cached_tokens
                 best_cache_rank = rank
@@ -149,26 +400,30 @@ class DPScheduler(SchedulerInterface):
                 f"assigned to rank {self.assigned_dp_rank[request.request_id]})")
         rank = self._find_best_rank_for_request(request)
         self.assigned_dp_rank[request.request_id] = rank
-        self.schedulers[rank].add_request(request)
 
+        self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
+        self._get_result_from_queue(rank, SchedulerCommand.ADD_REQUEST)
+
+    @time_function
     def schedule(self) -> DPSchedulerOutput:
         """
         Main scheduling method that coordinates all DP rank schedulers.
 
         Process:
         1. Add any new requests to appropriate DP ranks
-        2. Run each scheduler independently
+        2. Run each scheduler independently in parallel
         3. Combine outputs from all schedulers
        4. Return unified scheduling result
         """
         # Run each scheduler independently
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.SCHEDULE, None))
+
+        # Collect outputs from all workers (blocking)
         rank_outputs = []
-        for rank
-
-
-            f"{len(scheduler.running)} running, {len(scheduler.waiting)} waiting"
-        )
-            output = scheduler.schedule()
+        for rank in range(self.dp_size):
+            output = self._get_result_from_queue(rank,
+                                                 SchedulerCommand.SCHEDULE)
             rank_outputs.append(output)
 
         # Cache scheduler outputs to use in `update_from_output`
@@ -292,10 +547,12 @@ class DPScheduler(SchedulerInterface):
         combined_bitmasks = []
 
         # Get grammar bitmask from each DP rank scheduler
-        for rank
-
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.GET_GRAMMAR_BITMASK,
+                                         rank_scheduler_outputs[rank]))
+        for rank in range(self.dp_size):
+            grammar_output = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_GRAMMAR_BITMASK)
             if grammar_output is not None:
                 combined_structured_output_request_ids.extend(
                     grammar_output.structured_output_request_ids)
@@ -328,10 +585,15 @@ class DPScheduler(SchedulerInterface):
                 model_runner_output)
         rank_scheduler_outputs = self.cached_schedulers_output.popleft()
         # Update each scheduler with its portion of the output
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.UPDATE_FROM_OUTPUT,
+                 (rank_scheduler_outputs[rank], rank_model_outputs[rank])))
+
         combined_engine_outputs = defaultdict(list)
-        for rank
-        rank_engine_outputs =
-
+        for rank in range(self.dp_size):
+            rank_engine_outputs = self._get_result_from_queue(
+                rank, SchedulerCommand.UPDATE_FROM_OUTPUT)
             for client_idx, engine_output in rank_engine_outputs.items():
                 combined_engine_outputs[client_idx].append(engine_output)
 
@@ -397,30 +659,62 @@ class DPScheduler(SchedulerInterface):
 
         # Forward to each scheduler
         for rank, req_ids in rank_request_ids.items():
-            self.
+            self.input_queues[rank].put(
+                (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
+            self._get_result_from_queue(rank, SchedulerCommand.FINISH_REQUESTS)
 
     def get_num_unfinished_requests(self) -> int:
         """Get total number of unfinished requests across all DP ranks."""
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
+
+        total = 0
+        for rank in range(self.dp_size):
+            count = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS)
+            total += count
+        return total
 
     def has_finished_requests(self) -> bool:
         """Check if any DP rank has finished requests."""
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
+
+        has_finished_any = False
+        for rank in range(self.dp_size):
+            has_finished_any |= self._get_result_from_queue(
+                rank, SchedulerCommand.HAS_FINISHED_REQUESTS)
+        return has_finished_any
 
     def get_request_counts(self) -> Tuple[int, int]:
         """Get total (running, waiting) request counts across all DP ranks."""
-
-
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_REQUEST_COUNTS, None))
+
+        total_running = 0
+        total_waiting = 0
+        for rank in range(self.dp_size):
+            running, waiting = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_REQUEST_COUNTS)
+            total_running += running
+            total_waiting += waiting
         return total_running, total_waiting
 
     def reset_prefix_cache(self) -> bool:
         """Reset prefix cache for all DP rank schedulers."""
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.RESET_PREFIX_CACHE, None))
+
+        all_success = True
+        for rank in range(self.dp_size):
+            success = self._get_result_from_queue(
+                rank, SchedulerCommand.RESET_PREFIX_CACHE)
+            all_success &= success
+        return all_success
 
     def make_stats(self,
                    spec_decoding_stats=None,
@@ -438,9 +732,14 @@ class DPScheduler(SchedulerInterface):
         combined_connector_prefix_cache_stats: Optional[
             PrefixCacheStats] = None
 
-        for
-
-
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.MAKE_STATS, (spec_decoding_stats,
+                                               kv_connector_stats)))
+
+        for rank in range(self.dp_size):
+            rank_stats = self._get_result_from_queue(
+                rank, SchedulerCommand.MAKE_STATS)
             if rank_stats is None:
                 continue
 
@@ -465,8 +764,7 @@ class DPScheduler(SchedulerInterface):
             combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits
 
         # Average KV cache usage across ranks
-        avg_kv_cache_usage = total_kv_cache_usage /
-        self.schedulers) if self.schedulers else 0.0
+        avg_kv_cache_usage = total_kv_cache_usage / self.dp_size if self.dp_size else 0.0
 
         return SchedulerStats(
             num_running_reqs=total_running_reqs,
@@ -494,18 +792,36 @@ class DPScheduler(SchedulerInterface):
             rank_draft_tokens[rank]["req_ids"].append(req_id)
             rank_draft_tokens[rank]["draft_token_ids"].append(tokens)
 
-        # Forward to each scheduler
         for rank, draft_data in rank_draft_tokens.items():
             # Create a draft_token_ids object for this rank (mock structure)
             rank_draft_token_ids = type(draft_token_ids)(
                 req_ids=draft_data["req_ids"],
                 draft_token_ids=draft_data["draft_token_ids"])
-            self.
+            self.input_queues[rank].put(
+                (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
+                 rank_draft_token_ids))
+            self._get_result_from_queue(
+                rank, SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS)
 
     def shutdown(self) -> None:
-        """Shutdown all DP rank
-
-
+        """Shutdown all DP rank scheduler worker processes."""
+        # Send shutdown command to all workers
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.SHUTDOWN, None))
+
+        # Wait for acknowledgment (blocking)
+        for rank in range(self.dp_size):
+            self._get_result_from_queue(rank, SchedulerCommand.SHUTDOWN)
+
+        # Terminate and join all processes
+        for process in self.processes:
+            process.join(timeout=5.0)
+            if process.is_alive():
+                process.terminate()
+                process.join()
+
+        # Restore original pickle
+        _disable_cloudpickle()
 
 
 def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
```
```diff
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
```diff
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any, Optional
 
 import jax
```
tpu_inference/distributed/tpu_connector.py:

```diff
@@ -88,7 +88,7 @@ if TYPE_CHECKING:
 from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
-                                            get_kv_transfer_port,
+                                             get_kv_transfer_port,
                                              get_side_channel_port)
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.tpu_runner import TPUModelRunner
```
```diff
@@ -442,10 +442,10 @@ class TPUConnectorWorker:
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
         self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
-        #
-        #
-        #
-        self.node_id =
+        # Default value for the non-distributed scenario.
+        # When the topology is initialized, the runner will update it
+        # based on topology_order_id.
+        self.node_id = 0
 
         # req_id: (kv, expiration_time)
         self.reqs_wait_pull: dict[ReqId, list[list[jax.Array], float]] = {}
```
```diff
@@ -457,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()
 
         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -473,7 +472,7 @@ class TPUConnectorWorker:
         self.pull_conns: dict[str, Any] = {}
         self.notif_sockets: dict[str, zmq.Socket] = {}
 
-        logger.info(f"TPUConnector Worker
+        logger.info(f"TPUConnector Worker --> init | "
                     f"ip={self.host_ip} | "
                     f"kv_transfer_port={self.kv_transfer_port} | "
                     f"side_channel_port={self.side_channel_port}")
@@ -489,6 +488,7 @@ class TPUConnectorWorker:
         self.zmq_cxt.destroy(linger=0)
 
     def register_runner(self, runner: TPUModelRunner):
+        self.node_id = runner.topology_order_id
         self.runner = runner
         self.mesh = runner.mesh
 
@@ -499,6 +499,11 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        logger.info(f"TPUConnector Worker --> register_runner | "
+                    f"node_id={self.node_id} | "
+                    f"ip={self.host_ip} | "
+                    f"kv_transfer_port={self.kv_transfer_port}")
+        self._maybe_start_p2p_server()
 
     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
```
```diff
@@ -694,9 +699,9 @@ class TPUConnectorWorker:
 
 def get_uuid() -> int:
     int128 = uuid4().int
-    # Must be 64-bit int, otherwise vllm output encoder would raise error.
-
-    return
+    # Must be less than 64 bits, otherwise the vLLM output encoder raises.
+    # Use 50 bits so Go doesn't truncate the int during JSON serialization.
+    return int128 >> 78
 
 
 @jax.jit
```