tpu-inference 0.12.0.dev20251222__py3-none-any.whl
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0

@@ -0,0 +1,786 @@
+# SPDX-License-Identifier: Apache-2.0
+import functools
+import itertools
+import math
+import os
+import queue
+import signal
+import threading
+import time
+import traceback
+from typing import Any, Callable, Optional, Tuple, TypeVar, Union
+
+import jax
+# ======================================================================================
+# Imports for DisaggEngineCoreProc (the vLLM adapter)
+# ======================================================================================
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.tasks import POOLING_TASKS, SupportedTask
+from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
+                                         init_none_hash)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType, UtilityOutput,
+                            UtilityResult)
+from vllm.v1.engine.core import EngineCore as vLLMEngineCore
+from vllm.v1.engine.core import EngineCoreProc as vLLMEngineCoreProc
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.request import Request, RequestStatus
+
+from tpu_inference import utils as common_utils
+from tpu_inference.core import disagg_executor, disagg_utils
+from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
+# ======================================================================================
+# Imports for _DisaggOrchestrator (decoupled from vLLM)
+# ======================================================================================
+from tpu_inference.runner.utils import LatencyTracker
+
+# This file contains two classes:
+# 1. _DisaggOrchestrator: The clean, decoupled core orchestration logic.
+# 2. DisaggEngineCoreProc: The vLLM-facing adapter that handles process management.
+
+logger = init_logger(__name__)
+
+POLLING_TIMEOUT_S = 2.5
+HANDSHAKE_TIMEOUT_MINS = 5
+
+_R = TypeVar('_R')  # Return type for collective_rpc
+
+# ======================================================================================
+# Class 1: The Decoupled Orchestrator
+# ======================================================================================
+
+
+class JetThread(threading.Thread):
+    """Thread that kills the program if it fails.
+
+    If a driver thread goes down, we can't operate.
+    """
+
+    def run(self):
+        try:
+            super().run()
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            print(f"Thread {self.name} encountered an error: {e}")
+            traceback.print_exc()
+            os.kill(os.getpid(), signal.SIGKILL)
+
+
+class _DisaggOrchestrator:
+    """Contains the core orchestration logic, decoupled from vLLM."""
+
+    def __init__(
+        self,
+        config: VllmConfig,
+        output_queue: queue.Queue,
+        prefill_engines: list[vLLMEngineCore],
+        decode_engines: list[vLLMEngineCore],
+        prefill_slice_sizes: tuple[int, ...],
+        decode_slice_sizes: tuple[int, ...],
+    ):
+        self._config = config
+        self._output_queue = output_queue
+        self._prefill_engines = prefill_engines
+        self._decode_engines = decode_engines
+
+        # Keep track of active requests.
+        self._requests: dict[str, Request] = {}
+
+        # Hack device config to pass in the subslice of TPUs.
+        slice_sizes = list(prefill_slice_sizes)
+        slice_sizes.extend(decode_slice_sizes)
+
+        self._transfer_backlogs = [
+            queue.Queue(4) for i in range(len(self._prefill_engines))
+        ]
+
+        self._decode_backlogs = {}
+        for idx, engine in enumerate(self._decode_engines):
+            # Determine the decode backlog length by dividing the remaining HBM by the max KV cache size of a single request.
+            runner = engine.model_executor.driver_worker.model_runner
+            hbm_usage = common_utils.hbm_usage_bytes(
+                engine.model_executor.driver_worker.devices)
+            if not hbm_usage:
+                self._decode_backlogs[idx] = queue.Queue(
+                    self._config.scheduler_config.max_num_seqs)
+                continue
+            hbm_free = [limit - used for used, limit in hbm_usage]
+            max_kv_bytes = len(runner.kv_caches) * (
+                runner.max_model_len // runner.cache_config.block_size) * (
+                    runner.kv_caches[0][0].nbytes) // len(hbm_free)
+            max_queue_len = min(hbm_free[0] // max_kv_bytes,
+                                self._config.scheduler_config.max_num_seqs)
+            logger.debug(
+                f"max kv bytes: {max_kv_bytes}, max_queue_len {max_queue_len}")
+            self._decode_backlogs[idx] = queue.Queue(max_queue_len)
+
+        self._prefill_threads = [
+            JetThread(
+                target=functools.partial(self._prefill, idx),
+                name=f"prefill-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._prefill_engines))
+        ]
+        self._transfer_threads = [
+            JetThread(
+                target=functools.partial(
+                    self._transfer,
+                    idx,
+                ),
+                name=f"transfer-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._prefill_engines))
+        ]
+        self._decode_threads = [
+            JetThread(
+                target=functools.partial(
+                    self._decode,
+                    idx,
+                ),
+                name=f"decode-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._decode_engines))
+        ]
+        self._all_threads = list(
+            itertools.chain(
+                self._prefill_threads,
+                self._transfer_threads,
+                self._decode_threads,
+            ))
+        self.live = True
+        # Start all threads
+        for t in self._all_threads:
+            t.start()
+
+    def add_request(self, request: Request):
+        """
+        Adds a new request to the orchestrator.
+
+        This is the main entry point for new work. It stores the request for
+        internal state tracking and hands it off to the first stage of the
+        processing pipeline (the prefill scheduler).
+        """
+        # Hand off the request to the prefill scheduler to be batched for execution.
+        self._prefill_engines[0].scheduler.add_request(request)
+
+        # Add to internal state for tracking by other threads.
+        # The key is the request_id, the value is the request object.
+        self._requests[request.request_id] = request
+
+    def _prefill(self, idx: int):
+        prefill_engine = self._prefill_engines[idx]
+        transfer_backlog = self._transfer_backlogs[idx]
+
+        while self.live:
+            if not prefill_engine.scheduler.has_requests():
+                time.sleep(0.05)
+                continue
+
+            scheduler_output = prefill_engine.scheduler.schedule()
+            with LatencyTracker(f"prefill-{idx}"):
+                future = prefill_engine.model_executor.execute_model(
+                    scheduler_output, non_block=True)
+                grammar_output = prefill_engine.scheduler.get_grammar_bitmask(
+                    scheduler_output)
+                with prefill_engine.log_error_detail(scheduler_output):
+                    model_output = future.result()
+                    if model_output is None:
+                        model_output = prefill_engine.model_executor.sample_tokens(
+                            grammar_output)
+                if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                    model_output = model_output.get_output()
+
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                logger.debug(f"Prefill result: {model_output}")
+
+            kv_cache_map: dict[str, Tuple[list[jax.Array], list[Any]]] = {}
+            for req_id, req_index in model_output.req_id_to_index.items():
+                if len(model_output.sampled_token_ids[req_index]) > 0:
+                    request = self._requests[req_id]
+                    block_ids = (prefill_engine.scheduler.kv_cache_manager.
+                                 get_block_ids(req_id))
+                    # Assume one KV cache group for now.
+                    kv_cache_map[req_id] = (
+                        prefill_engine.model_executor.driver_worker.
+                        model_runner.get_kv_cache_for_block_ids(
+                            block_ids[0]), request.block_hashes)
+                    logger.debug(f"prefill done: for {req_id}")
+            transfer_backlog.put(kv_cache_map, block=True)
+
+            # Tweak model_output to let the scheduler know kv_transfer is done for requests, so they can be freed.
+            engine_core_outputs = prefill_engine.scheduler.update_from_output(
+                scheduler_output, model_output)  # type: ignore
+
+            for req_id, req_index in model_output.req_id_to_index.items():
+                if len(model_output.sampled_token_ids[req_index]) > 0:
+                    request = self._requests[req_id]
+                    logger.debug(
+                        f"request block_hashes at prefill: {request.block_hashes}"
+                    )
+                    logger.debug(
+                        f"request-{req_id}: tokens={request.all_token_ids} after prefill"
+                    )
+                    # Remove request from the prefill engine.
+                    if req_id in prefill_engine.scheduler.requests:
+                        request = prefill_engine.scheduler.requests[req_id]
+                        prefill_engine.scheduler.running.remove(request)
+                        prefill_engine.scheduler.encoder_cache_manager.free(
+                            request)
+
+                        prefill_engine.scheduler.kv_cache_manager.free(
+                            request)
+
+                        prefill_engine.scheduler.requests.pop(req_id)
+
+            for output in (engine_core_outputs.items()
+                           if engine_core_outputs else ()):
+                self._output_queue.put_nowait(output)
+
+    def _transfer(self, idx: int):
+        """Transfers the kv cache on an active request to the least full
+        decode backlog."""
+        transfer_backlog = self._transfer_backlogs[idx]
+        while self.live:
+            # The transfer thread can just sleep until it has work to do.
+            kv_cache_map = transfer_backlog.get(block=True)
+            if kv_cache_map is None:
+                break
+
+            logger.debug(
+                f"transfer-{idx}: KV Cache items received: {kv_cache_map.keys()}"
+            )
+
+            push_targets = []
+            for req_id, (kv_cache, block_hashes) in kv_cache_map.items():
+                target_idx = -1
+                cnt = 9999999
+                for i, e in enumerate(self._decode_engines):
+                    req_cnt = sum(e.scheduler.get_request_counts())
+                    if req_cnt < cnt:
+                        cnt = req_cnt
+                        target_idx = i
+
+                # Only transfer the KVCache for the disaggregated serving.
+                with LatencyTracker("KVCacheTransfer"):
+                    kv_cache = self._decode_engines[
+                        target_idx].model_executor.driver_worker.model_runner.transfer_kv_cache(
+                            kv_cache)
+
+                # TODO(fhzhang): Now how do we get the kv cache to the decode engine?
+                prefill_output = {
+                    "cache": kv_cache,
+                    "req_id": req_id,
+                    "block_hashes": block_hashes,
+                }
+                push_targets.append((target_idx, prefill_output))
+
+            for target_idx, prefill_output in push_targets:
+                self._decode_backlogs[target_idx].put(prefill_output,
+                                                      block=True)
+                logger.debug(
+                    "Successfully transferred prefill request %s "
+                    "from prefill engine %d to decode engine %d. decode backlog len %d",
+                    prefill_output["req_id"],
+                    idx,
+                    target_idx,
+                    self._decode_backlogs[target_idx].qsize(),
+                )
+
+    def _decode(self, idx: int):
+        decode_engine = self._decode_engines[idx]
+        decode_backlog = self._decode_backlogs[idx]
+
+        while self.live:
+            block = not decode_engine.scheduler.has_requests()
+
+            while True:
+                # We need to check the input batch as well, since request completion is delayed
+                # from the scheduler to the runner.
+                if (sum(decode_engine.scheduler.get_request_counts())
+                        >= self._config.scheduler_config.max_num_seqs
+                        or decode_engine.model_executor.driver_worker.
+                        model_runner.input_batch.num_reqs
+                        >= self._config.scheduler_config.max_num_seqs):
+                    break
+
+                try:
+                    prefill_output = decode_backlog.get(block=block,
+                                                        timeout=1.0)
+                except queue.Empty:
+                    if block:
+                        continue
+                    break
+
+                if prefill_output is None:
+                    logger.debug(
+                        f"decode-{idx} Empty output, and we are idle, exiting..."
+                    )
+                    break
+
+                # We got a request, so set block to False.
+                block = False
+
+                # Insert the request into the decoder.
+                req_id = prefill_output["req_id"]
+                vllm_request = self._requests[req_id]
+                # Cache num_computed_tokens. The token count the KV manager uses to allocate blocks
+                # is computed as num_computed_tokens + num_new_tokens, so without caching
+                # the token number would double.
+                prompt_tokens = vllm_request.num_computed_tokens
+                vllm_request.num_computed_tokens = 0
+                kv_cache = prefill_output["cache"]
+
+                kv_cache_manager = decode_engine.scheduler.kv_cache_manager
+                kv_cache_manager.allocate_slots(
+                    vllm_request,
+                    prompt_tokens,
+                )
+                vllm_request.num_computed_tokens = prompt_tokens
+                new_block_ids = kv_cache_manager.get_block_ids(req_id)
+                logger.debug(
+                    f"inserting {req_id} new_block_ids {new_block_ids}")
+                if len(new_block_ids[0]) != math.ceil(
+                        prompt_tokens / self._config.cache_config.block_size):
+                    logger.warning("Running out of blocks in decode engine! ")
+                    break
+
+                decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
+                    vllm_request, kv_cache, new_block_ids)
+
+                vllm_request.status = RequestStatus.RUNNING
+                block_hashes = prefill_output["block_hashes"]
+                vllm_request.block_hashes = block_hashes
+                decode_engine.scheduler.running.append(vllm_request)
+                decode_engine.scheduler.requests[req_id] = vllm_request
+
+                self._requests.pop(req_id)
+
+            scheduler_output = decode_engine.scheduler.schedule()
+
+            logger.debug(f'''decode-{idx}: scheduler_output -
+                {scheduler_output.scheduled_cached_reqs.num_computed_tokens},
+                new block ids - {scheduler_output.scheduled_cached_reqs.new_block_ids}'''
+                         )
+
+            with LatencyTracker(f"decode-{idx}"):
+                future = decode_engine.model_executor.execute_model(
+                    scheduler_output, non_block=True)
+                grammar_output = decode_engine.scheduler.get_grammar_bitmask(
+                    scheduler_output)
+                with decode_engine.log_error_detail(scheduler_output):
+                    model_output = future.result()
+                    if model_output is None:
+                        model_output = decode_engine.model_executor.sample_tokens(
+                            grammar_output)
+                if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                    model_output = model_output.get_output()
+
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                logger.debug(f"Decode result: {model_output}")
+
+            engine_core_outputs = decode_engine.scheduler.update_from_output(
+                scheduler_output, model_output)  # type: ignore
+            for output in (engine_core_outputs.items()
+                           if engine_core_outputs else ()):
+                self._output_queue.put_nowait(output)
+
+    def shutdown(self):
+        for e in self._prefill_engines:
+            e.shutdown()
+        for e in self._decode_engines:
+            e.shutdown()
+
+
+# ======================================================================================
+# Class 2: The vLLM-Facing Adapter
+# ======================================================================================
+
+
+def _create_engine_cores(
+    slice_sizes: tuple[int, ...],
+    vllm_config: VllmConfig,
+    log_stats: bool,
+    executor_fail_callback: Optional[Callable] = None,
+) -> list[vLLMEngineCore]:
+    engine_cores = []
+    for _ in slice_sizes:
+        engine_core = vLLMEngineCore(
+            vllm_config,
+            disagg_executor.DisaggExecutor,
+            log_stats,
+            executor_fail_callback,
+        )
+
+        engine_cores.append(engine_core)
+        logger.warning("Disaggregated engine core created.")
+
+    return engine_cores
+
+
+def _get_slice_sizes(devices):
+    prefill_slice_sizes = disagg_utils.get_prefill_slices()
+    decode_slice_sizes = disagg_utils.get_decode_slices()
+    if isinstance(prefill_slice_sizes[0], int):
+        prefill_chip_cnt = sum(prefill_slice_sizes)
+    else:
+        prefill_chip_cnt = sum([math.prod(t) for t in prefill_slice_sizes])
+    if isinstance(decode_slice_sizes[0], int):
+        decode_chip_cnt = sum(decode_slice_sizes)
+    else:
+        decode_chip_cnt = sum([math.prod(t) for t in decode_slice_sizes])
+    assert decode_chip_cnt + prefill_chip_cnt <= len(devices)
+    assert prefill_chip_cnt > 0 and decode_chip_cnt > 0
+
+    slice_sizes = list(prefill_slice_sizes)
+    slice_sizes.extend(decode_slice_sizes)
+    return prefill_slice_sizes, decode_slice_sizes, slice_sizes
+
+
+class DisaggEngineCore(vLLMEngineCore):
+    """The vLLM-facing adapter that handles process management and I/O. Modifies vLLMEngineCore and is only used in the in-process EngineCore client."""
+
+    @staticmethod
+    def is_supported() -> bool:
+        """
+        Returns True if this engine can run in the current environment.
+        """
+        return disagg_utils.is_disagg_enabled()
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+        executor_fail_callback: Optional[Callable] = None,
+    ):
+        self.vllm_config = vllm_config
+
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+
+        self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
+        prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+            self.devices)
+
+        if isinstance(slice_sizes[0], int):
+            setattr(vllm_config.device_config, "slice",
+                    (0, slice_sizes, self.devices))
+        else:
+            setattr(vllm_config.device_config, "slice",
+                    ((0, 0), 0, slice_sizes, self.devices))
+        logger.info(
+            f"Creating DisaggEngineCore with slice_sizes {slice_sizes}...")
+
+        self._prefill_engines = _create_engine_cores(
+            prefill_slice_sizes,
+            vllm_config,
+            log_stats,
+            executor_fail_callback,
+        )
+        logger.info(
+            f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+        )
+
+        self._decode_engines = _create_engine_cores(
+            decode_slice_sizes,
+            vllm_config,
+            log_stats,
+            executor_fail_callback,
+        )
+        logger.info(
+            f"{len(self._decode_engines)} Disaggregated decode engines created."
+        )
+
+        self.batch_queue = None
+
+        self.request_block_hasher = None
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self._prefill_engines[0].scheduler.get_kv_connector()
+                is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = common_utils.get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
+        self.mm_receiver_cache = None
+        self._orchestrator = _DisaggOrchestrator(
+            config=vllm_config,
+            output_queue=self.output_queue,
+            prefill_engines=self._prefill_engines,
+            decode_engines=self._decode_engines,
+            prefill_slice_sizes=prefill_slice_sizes,
+            decode_slice_sizes=decode_slice_sizes,
+        )
+        # for vLLM compatibility
+        self.model_executor = self._prefill_engines[0].model_executor
+
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self._prefill_engines[0].model_executor.supported_tasks
+
+    def add_request(self, request: Request, request_wave: int = 0):
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
+
+        if request.kv_transfer_params is not None and (
+                not self.scheduler.get_kv_connector()):
+            logger.warning("Got kv_transfer_params, but no KVConnector found. "
+                           "Disabling KVTransfer for this request.")
+
+        self._orchestrator.add_request(request)
+
+    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
+        client_idx, output = self.output_queue.get()
+        # logger.warning(f"step output: {output}")
+        time.sleep(0.03)
+        return {client_idx: output}, True
+
+    def shutdown(self):
+        self._orchestrator.shutdown()
+
+    def reset_mm_cache(self):
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror).
+        for engine in itertools.chain(self._prefill_engines,
+                                      self._decode_engines):
+            if engine.scheduler.has_unfinished_requests():
+                logger.warning(
+                    "Resetting the multi-modal cache when requests are "
+                    "in progress may lead to desynced internal caches.")
+
+            if engine.mm_receiver_cache is not None:
+                engine.mm_receiver_cache.clear_cache()
+
+    def reset_prefix_cache(self):
+        for engine in itertools.chain(self._prefill_engines,
+                                      self._decode_engines):
+            engine.scheduler.reset_prefix_cache()
+
+
+class DisaggEngineCoreProc(vLLMEngineCoreProc):
+    """The vLLM-facing adapter that handles process management and I/O."""
+
+    @staticmethod
+    def is_supported() -> bool:
+        """
+        Returns True if this engine can run in the current environment.
+        """
+        return disagg_utils.is_disagg_enabled()
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_client: bool,
+        handshake_address: str,
+        executor_class: type[Executor],
+        log_stats: bool,
+        client_handshake_address: Optional[str] = None,
+        engine_index: int = 0,
+        **kwargs,
+    ):
+        if 'dp_rank' in kwargs or 'local_dp_rank' in kwargs:
+            logger.debug(
+                "Ignoring data parallelism arguments for non-DP disaggregated engine."
+            )
+        # We don't invoke the super class's ctor as we are not really the
+        # engine core to be executed; instead we create other instances of
+        # engine cores and let them do the work.
+        self.vllm_config = vllm_config
+
+        # We should be taking the input from the client; the code below is forked from
+        # vllm.v1.engine.core.EngineCoreProc.
+        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+
+        self.engine_index = engine_index
+        identity = self.engine_index.to_bytes(length=2, byteorder="little")
+        self.engines_running = False
+
+        self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
+        prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+            self.devices)
+
+        if isinstance(slice_sizes[0], int):
+            setattr(vllm_config.device_config, "slice",
+                    (0, slice_sizes, self.devices))
+        else:
+            setattr(vllm_config.device_config, "slice",
+                    ((0, 0), 0, slice_sizes, self.devices))
+        logger.info(
+            f"Creating DisaggEngineCoreProc with slice_sizes {slice_sizes}...")
+
+        def executor_fail_callback():
+            self.input_queue.put_nowait(
+                (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+        # Don't complete handshake until DP coordinator ready message is
+        # received.
+        with self._perform_handshakes(handshake_address, identity,
+                                      local_client, vllm_config,
+                                      client_handshake_address) as addresses:
+            self.client_count = len(addresses.outputs)
+
+            # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
+            self.frontend_stats_publish_address = (
+                addresses.frontend_stats_publish_address)
+            self.publish_dp_lb_stats = (
+                self.has_coordinator
+                and not vllm_config.parallel_config.data_parallel_external_lb)
+            # Background Threads and Queues for IO. These enable us to
+            # overlap ZMQ socket IO with GPU since they release the GIL,
+            # and to overlap some serialization/deserialization with the
+            # model forward pass.
+            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+
+            self._prefill_engines = _create_engine_cores(
+                prefill_slice_sizes,
+                vllm_config,
+                log_stats,
+                executor_fail_callback,
+            )
+            logger.info(
+                f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+            )
+
+            self._decode_engines = _create_engine_cores(
+                decode_slice_sizes,
+                vllm_config,
+                log_stats,
+                executor_fail_callback,
+            )
+            logger.info(
+                f"{len(self._decode_engines)} Disaggregated decode engines created."
+            )
+
+            ready_event = threading.Event()
+            input_thread = threading.Thread(target=self.process_input_sockets,
+                                            args=(addresses.inputs,
+                                                  addresses.coordinator_input,
+                                                  identity, ready_event),
+                                            daemon=True)
+            input_thread.start()
+
+            self.output_thread = threading.Thread(
+                target=self.process_output_sockets,
+                args=(addresses.outputs, addresses.coordinator_output,
+                      self.engine_index),
+                daemon=True)
+            self.output_thread.start()
+            while not ready_event.wait(timeout=10):
+                if not input_thread.is_alive():
+                    raise RuntimeError(
+                        "Input socket thread died during startup")
+                if addresses.coordinator_input is not None:
+                    logger.info(
+                        "Waiting for READY message from DP Coordinator...")
+        self.request_block_hasher = None
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self._prefill_engines[0].scheduler.get_kv_connector()
+                is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = common_utils.get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
+        self.mm_receiver_cache = None
+        self._orchestrator = _DisaggOrchestrator(
+            config=vllm_config,
+            output_queue=self.output_queue,
+            prefill_engines=self._prefill_engines,
+            decode_engines=self._decode_engines,
+            prefill_slice_sizes=prefill_slice_sizes,
+            decode_slice_sizes=decode_slice_sizes,
+        )
+
+    def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
+
+        self._orchestrator.add_request(request)
+
+    def _handle_client_request(self, request_type: EngineCoreRequestType,
+                               request: Any) -> None:
+        """Dispatch request from client."""
+        if request_type == EngineCoreRequestType.ADD:
+            req, request_wave = request
+            self.add_request(req)
+        elif request_type == EngineCoreRequestType.ABORT:
+            # TODO(fhzhang): we need to keep track of which engine is processing
+            # the request and finish it there.
+            # owner_engine.scheduler.finish_requests(request, RequestStatus.FINISHED_ABORTED)
+            pass
+        elif request_type == EngineCoreRequestType.UTILITY:
+            client_idx, call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self._prefill_engines[0], method_name)
+                result = method(*self._convert_msgspec_args(method, args))
+                output.result = UtilityResult(result)
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                (client_idx, EngineCoreOutputs(utility_output=output)))
+        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
+            raise RuntimeError("Executor failed.")
+        else:
+            logger.error("Unrecognized input request type encountered: %s",
+                         request_type)
+
+    def run_busy_loop(self):
+        """Core busy loop of the EngineCore."""
+
+        # Loop until process is sent a SIGINT or SIGTERM
+        while True:
+            while not self.input_queue.empty():
+                req = self.input_queue.get_nowait()
+                self._handle_client_request(*req)
+            # Yield control to other threads, as we are not doing any real work.
+            # Without this sleep, we'd be hogging all the cpu cycles with our run_busy_loop.
+            time.sleep(0.01)
+
+    def shutdown(self):
+        self._orchestrator.shutdown()