sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,580 @@
|
|
1
|
+
"""
|
2
|
+
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
|
3
|
+
Mimics TokenizerManager's state management and ZMQ communication patterns.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import asyncio
|
7
|
+
import dataclasses
|
8
|
+
import logging
|
9
|
+
import os
|
10
|
+
import signal
|
11
|
+
import sys
|
12
|
+
import threading
|
13
|
+
import time
|
14
|
+
from typing import Any, Dict, List, Optional, Union
|
15
|
+
|
16
|
+
import grpc
|
17
|
+
import zmq
|
18
|
+
import zmq.asyncio
|
19
|
+
|
20
|
+
from sglang.srt.managers.io_struct import (
|
21
|
+
AbortReq,
|
22
|
+
BatchEmbeddingOut,
|
23
|
+
BatchTokenIDOut,
|
24
|
+
HealthCheckOutput,
|
25
|
+
TokenizedEmbeddingReqInput,
|
26
|
+
TokenizedGenerateReqInput,
|
27
|
+
)
|
28
|
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
29
|
+
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
30
|
+
from sglang.utils import get_exception_traceback
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
class GrpcSignalHandler:
    """Signal callbacks for the gRPC server process.

    The handlers here are deliberately small: SIGTERM only flips the
    manager's shutdown flag, and SIGQUIT tears down the process tree.
    Crash dumps are left to the scheduler process (see the log messages).
    """

    def __init__(self, grpc_manager):
        # Object whose ``gracefully_exit`` attribute is set on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the SIGTERM and request a graceful exit from the manager."""
        notice = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(notice)
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the scheduler failure, then kill this process and its children."""
        reports = (
            (
                logger.error,
                "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server.",
            ),
            (
                logger.info,
                "Note: Crash dumps are handled by the scheduler process, not the gRPC server.",
            ),
        )
        for emit, text in reports:
            emit(text)

        # No dump is produced here; the whole process tree is terminated.
        kill_process_tree(os.getpid(), include_parent=True)
|
58
|
+
|
59
|
+
|
60
|
+
@dataclasses.dataclass
class GrpcReqState:
    """Per-request bookkeeping record kept by the gRPC request manager."""

    # --- Identity -----------------------------------------------------
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # --- Hand-off to the consumer -------------------------------------
    # Results are pushed onto ``out_queue``; ``event`` is set and
    # ``finished`` flipped once the request is complete.
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # --- Timing / metrics (mirrors TokenizerManager's ReqState) -------
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # --- Streaming progress -------------------------------------------
    last_output_offset: int = 0
    stream_finished: bool = False

    # --- Accumulated output -------------------------------------------
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(
        default_factory=list
    )
    input_token_logprobs_idx: List[int] = dataclasses.field(
        default_factory=list
    )
    output_token_logprobs_val: List[float] = dataclasses.field(
        default_factory=list
    )
    output_token_logprobs_idx: List[int] = dataclasses.field(
        default_factory=list
    )
    input_top_logprobs_val: List[List[float]] = dataclasses.field(
        default_factory=list
    )
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(
        default_factory=list
    )
    output_top_logprobs_val: List[List[float]] = dataclasses.field(
        default_factory=list
    )
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(
        default_factory=list
    )

    # --- Session ------------------------------------------------------
    session_id: Optional[str] = None
    is_session_request: bool = False
|
100
|
+
|
101
|
+
|
102
|
+
class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.

    Requests are pushed to the scheduler over a ZMQ PUSH socket and results
    are pulled back over a ZMQ PULL socket by ``handle_loop``. Per-request
    state lives in ``rid_to_state``, keyed by request id.
    """

    # Seconds a finished generation request stays in ``rid_to_state`` before
    # it is dropped.
    _FINISHED_STATE_TTL = 5.0

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the gRPC request manager.

        Args:
            server_args: Server configuration; stored as-is.
            port_args: Supplies the IPC names the two ZMQ sockets bind to.
        """
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        context = zmq.asyncio.Context(2)

        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )

        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        # Long-running loops (handle_loop, sigterm_watchdog).
        self.asyncio_tasks: set = set()
        # Short-lived helper tasks. The event loop only keeps weak references
        # to tasks, so they are held here until they finish.
        self._background_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.request_counter = 0
        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )

    def _spawn_background_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` and hold a strong reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _remove_state_later(self, rid: str, state: GrpcReqState):
        """Drop ``rid`` from tracking after a delay, if it still maps to ``state``."""
        await asyncio.sleep(self._FINISHED_STATE_TTL)
        # Only remove the entry this timer was created for; a newer request
        # that reused the same id must not be evicted.
        if self.rid_to_state.get(rid) is state:
            del self.rid_to_state[rid]

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> asyncio.Queue:
        """
        Submit a generation request to the scheduler.

        Args:
            obj: The tokenized request; its ``rid`` is overwritten with the
                request id used here.
            request_id: Id to use; when ``None`` a ``grpc-<counter>`` id is
                generated.
            grpc_context: Stored on the request state.

        Returns:
            A queue for streaming outputs.

        Raises:
            RuntimeError: If the request could not be sent to the scheduler.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # TODO: support log_request

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Track session if needed
        if hasattr(obj, "session_params") and obj.session_params:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        # Register state
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Send to scheduler via ZMQ
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            # Clean up on failure; pop() tolerates the entry already being gone.
            self.rid_to_state.pop(request_id, None)
            raise RuntimeError(f"Failed to send request to scheduler: {e}") from e

        return state.out_queue

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.

        Args:
            obj: The tokenized request; its ``rid`` is overwritten with the
                request id used here.
            request_id: Id to use; when ``None`` a ``grpc-embed-<counter>``
                id is generated.

        Returns:
            A future that will contain the embedding result, or the exception
            raised while sending or waiting.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-embed-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Register state
        self.rid_to_state[request_id] = state

        # Create future for result
        future = asyncio.Future()

        # Send to scheduler
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            self.rid_to_state.pop(request_id, None)
            future.set_exception(e)
            return future

        # Wait for result in background
        async def wait_for_result():
            try:
                # Wait for completion
                await state.event.wait()
                # Get result from queue
                result = await state.out_queue.get()
                # The caller may have cancelled the future in the meantime;
                # resolving a done future raises InvalidStateError.
                if not future.done():
                    future.set_result(result)
            except Exception as e:
                if not future.done():
                    future.set_exception(e)
            finally:
                # Clean up
                self.rid_to_state.pop(request_id, None)

        self._spawn_background_task(wait_for_result())
        return future

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request.

        Returns:
            ``False`` if the request is unknown or the abort could not be
            sent to the scheduler, ``True`` otherwise.
        """
        if request_id not in self.rid_to_state:
            return False

        # Send abort to scheduler
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        # Mark as finished
        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()

            # Send abort notification to output queue
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing."""
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.

        Runs until ``gracefully_exit`` is set. Errors raised while handling
        a message are logged and the loop continues.
        """
        while not self.gracefully_exit:
            try:
                # Receive from scheduler
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause: the received object is held until resumed.
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                # Handle different output types
                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler.

        For every request id in the batch that is still tracked, builds an
        output dict, updates the request state and pushes the dict onto the
        request's output queue.
        """
        # Process each request in the batch
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Update metrics
            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            # Extract output for this request
            output_data = {
                "request_id": rid,
                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
                        else None
                    ),
                },
            }

            # Add logprobs if available
            if batch_out.output_token_logprobs_val and i < len(
                batch_out.output_token_logprobs_val
            ):
                output_data["logprobs"] = {
                    "tokens": batch_out.output_token_logprobs_val[i],
                    "top_logprobs": (
                        batch_out.output_top_logprobs_val[i]
                        if batch_out.output_top_logprobs_val
                        and i < len(batch_out.output_top_logprobs_val)
                        else None
                    ),
                }

            # Update state
            if output_data["text"]:
                state.text += output_data["text"][state.last_output_offset :]
                state.last_output_offset = len(output_data["text"])

            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Send to output queue
            await state.out_queue.put(output_data)

            # Handle completion
            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay. ``rid`` and ``state``
                # are passed as arguments so the timer is bound to THIS
                # request rather than to whatever the loop variable holds
                # when the timer fires.
                self._spawn_background_task(self._remove_state_later(rid, state))

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler."""
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Create result
            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": (
                    batch_out.finish_reason[i] if batch_out.finish_reason else None
                ),
            }

            # Send result
            await state.out_queue.put(result)

            # Mark as finished
            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler."""
        rid = health_out.rid

        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]

        # Create health check result
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }

        # Send result
        await state.out_queue.put(result)

        # Mark as finished
        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ.

        Logs and re-raises any exception raised by the send call.
        """
        try:
            # NOTE(review): the socket comes from a zmq.asyncio context and the
            # call's return value is not awaited here - confirm this is intended.
            self.send_to_scheduler.send_pyobj(obj)
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def record_request_for_crash_dump(self, obj):
        """Record request for potential crash dump (first 100 requests only)."""
        if len(self.crash_dump_request_list) < 100:
            self.crash_dump_request_list.append(
                {
                    "time": time.time(),
                    "request_id": getattr(obj, "rid", "unknown"),
                    "type": type(obj).__name__,
                }
            )

    async def shutdown(self):
        """Gracefully shutdown the request manager.

        Notifies every unfinished request, stops the long-running loops and
        closes both ZMQ sockets.
        """
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all pending requests. Iterate over a snapshot because other
        # tasks may add or remove entries while this coroutine is suspended.
        for rid, state in list(self.rid_to_state.items()):
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Stop the long-running loops. handle_loop is normally blocked in
        # recv_pyobj() and would never observe ``gracefully_exit`` on its own,
        # so it has to be cancelled before it can be awaited.
        current = asyncio.current_task()
        loops = [task for task in self.asyncio_tasks if task is not current]
        for task in loops:
            if not task.done():
                task.cancel()

        # Wait for tasks to complete
        if loops:
            await asyncio.gather(*loops, return_exceptions=True)

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern.

        Only the first call has any effect; later calls return immediately.
        """
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern.

        Polls ``gracefully_exit`` once per second and returns once it is set.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)
|
565
|
+
|
566
|
+
|
567
|
+
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.

    Awaits ``func()``. If it raises an ``Exception``, the traceback is logged,
    a crash dump is attempted when the owning manager provides one, and the
    process tree is killed before exiting with status 1.

    Args:
        func: A zero-argument callable returning an awaitable, typically a
            bound coroutine method of ``GrpcRequestManager``.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            # The dump hook is optional: look it up defensively so a missing
            # or failing hook can never prevent the process from being
            # terminated below.
            dump_hook = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump_hook):
                try:
                    dump_hook()
                except Exception:
                    logger.error(
                        f"Failed to dump requests before crash: {get_exception_traceback()}"
                    )

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|