sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
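The hunks below are from `sglang/srt/entrypoints/grpc_request_manager.py` (+330 -55): they add n>1 parallel sampling, logprob forwarding, an optional disaggregation bootstrap server, and a more deliberate shutdown path.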
```diff
@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
 """
 
 import asyncio
+import copy
 import dataclasses
 import logging
 import os
@@ -11,7 +12,8 @@ import signal
 import sys
 import threading
 import time
-
+import uuid
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
 import grpc
 import zmq
@@ -19,8 +21,8 @@ import zmq.asyncio
 
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
     HealthCheckOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -79,11 +81,10 @@ class GrpcReqState:
     last_completion_tokens: int = 1
 
     # Streaming state
-    last_output_offset: int = 0
     stream_finished: bool = False
+    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
 
-    #
-    text: str = ""
+    # Token accumulation (for non-streaming)
    output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -109,22 +110,23 @@ class GrpcRequestManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        bootstrap_server=None,
     ):
         """Initialize the gRPC request manager."""
         self.server_args = server_args
         self.port_args = port_args
 
         # ZMQ Communication Setup (same pattern as TokenizerManager)
-        context = zmq.asyncio.Context(2)
+        self.context = zmq.asyncio.Context(2)
 
         # Socket for receiving outputs from scheduler
         self.recv_from_scheduler = get_zmq_socket(
-            context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
+            self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
         )
 
         # Socket for sending requests to scheduler
         self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
         )
 
         # State Management (from TokenizerManager)
@@ -139,41 +141,158 @@ class GrpcRequestManager:
         self.is_pause_cond = asyncio.Condition()
 
         # Metrics
-        self.request_counter = 0
-        self.request_counter_lock = asyncio.Lock()
         self.last_receive_tstamp = time.time()
 
         # Crash dump for debugging
         self.crash_dump_request_list = []
         self.crash_dump_performed = False
 
+        # Bootstrap server (passed from serve_grpc, not started here)
+        self.bootstrap_server = bootstrap_server
+
         logger.info(
             f"GrpcRequestManager initialized with ZMQ IPC: "
             f"recv={port_args.detokenizer_ipc_name}, "
             f"send={port_args.scheduler_input_ipc_name}"
         )
+        if self.bootstrap_server:
+            logger.info(
+                f"Bootstrap server initialized for disaggregation mode: "
+                f"{server_args.disaggregation_mode}"
+            )
 
     async def generate_request(
         self,
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ) ->
+    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
         """
-        Submit a generation request to the scheduler.
-
+        Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+        This method implements the same two-phase approach as tokenizer_manager.py:
+        1. Phase 1: Send prefix caching request (max_new_tokens=0)
+        2. Phase 2: Send n generation requests that reuse the cached prefix
+
+        Yields individual responses for streaming, or aggregated responses for non-streaming.
         """
+        n = getattr(obj.sampling_params, "n", 1)
+
+        if n <= 1:
+            async for response in self._handle_single_request(
+                obj, request_id, grpc_context
+            ):
+                yield response
+            return
+
+        # N>1 handling - two-phase approach
+        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+        # Generate base request ID if not provided
+        if request_id is None:
+            base_request_id = f"grpc-{uuid.uuid4().hex}"
+        else:
+            base_request_id = request_id
+
+        # Phase 1: Cache the common prefix
+        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+        prefix_obj = copy.copy(obj)
+        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+        # Send prefix caching request and consume response
+        async for _ in self._handle_single_request(
+            prefix_obj, f"{base_request_id}-prefix", grpc_context
+        ):
+            # Consume prefix response (usually just one chunk with finish_reason)
+            pass
+
+        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+        # Phase 2: Generate n parallel requests
+        logger.debug(f"Phase 2: Generating {n} parallel requests")
+        generators = []
+        request_ids = []
+
+        for i in range(n):
+            # Create individual generation request
+            gen_obj = copy.copy(obj)
+            gen_obj.sampling_params = copy.copy(obj.sampling_params)
+            gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+            gen_request_id = f"{base_request_id}-{i}"
+            request_ids.append(gen_request_id)
+
+            # Start generation request
+            generators.append(
+                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+            )
+
+        # Handle response aggregation
+        is_stream = getattr(obj, "stream", False)
+
+        if not is_stream:
+            # Non-streaming: collect all responses and return as batch
+            logger.debug(f"Non-streaming mode: collecting {n} responses")
+            responses = []
+            for generator in generators:
+                async for response in generator:
+                    responses.append(response)
+            yield responses  # Return all responses as a batch
+        else:
+            # Streaming mode: multiplex responses with index for ordering
+            logger.debug(f"Streaming mode: multiplexing {n} streams")
+            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+            # Create async tasks for all generators
+            task_map = {}
+            for generator in generators:
+                task = asyncio.create_task(generator.__anext__())
+                task_map[task] = generator
+
+            # Process responses as they arrive
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+
+                for task in done:
+                    generator = task_map.pop(task)
+                    try:
+                        response = await task
+
+                        # Add index for client-side ordering
+                        if isinstance(response, dict) and "meta_info" in response:
+                            response_rid = response["meta_info"].get("id", "")
+                            if response_rid in rid_to_index:
+                                response["index"] = rid_to_index[response_rid]
+
+                        yield response
+
+                        # Create next task for this generator
+                        next_task = asyncio.create_task(generator.__anext__())
+                        task_map[next_task] = generator
+
+                    except StopAsyncIteration:
+                        # This generator is finished
+                        pass
+
+    async def _handle_single_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ):
+        """Handle a single request - core implementation without n>1 logic."""
         # Generate request ID if not provided
         if request_id is None:
-
-            request_id = f"grpc-{self.request_counter}"
-            self.request_counter += 1
+            request_id = f"grpc-{uuid.uuid4().hex}"
 
         obj.rid = request_id
 
+        # Create and register request state
         # TODO: support log_request
-
-        # Create request state
         state = GrpcReqState(
             request_id=request_id,
             grpc_context=grpc_context,
```
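The streaming branch of Phase 2 above fans in multiple async generators by keeping one pending `__anext__()` task per generator and re-arming it after every yielded item. Below is a minimal standalone sketch of that same pattern; all names here are invented for illustration (only the `asyncio.wait`/`FIRST_COMPLETED` idiom comes from the diff):

```python
import asyncio


async def fake_stream(rid: str, chunks: int, delay: float):
    # Stand-in for _handle_single_request: yields chunks at its own pace.
    for i in range(chunks):
        await asyncio.sleep(delay)
        yield {"meta_info": {"id": rid}, "chunk": i}


async def multiplex(generators):
    # One pending __anext__() task per generator; whichever completes first
    # is yielded, then a fresh task is armed for that generator.
    task_map = {asyncio.create_task(g.__anext__()): g for g in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                item = task.result()
            except StopAsyncIteration:
                continue  # this generator is exhausted
            yield item
            task_map[asyncio.create_task(gen.__anext__())] = gen


async def main():
    streams = [fake_stream("req-0", 3, 0.02), fake_stream("req-1", 3, 0.03)]
    async for item in multiplex(streams):
        print(item["meta_info"]["id"], item["chunk"])


asyncio.run(main())
```

Yield order here is arrival order, which is why the real code attaches an `index` to each response for client-side reordering.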
```diff
@@ -189,19 +308,51 @@ class GrpcRequestManager:
             state.session_id = obj.session_params.session_id
             state.is_session_request = True
 
-        # Register state
         self.rid_to_state[request_id] = state
         self.record_request_for_crash_dump(obj)
 
-        # Send to scheduler via ZMQ
         try:
+            # Send to scheduler - let exceptions bubble up to grpc_server.py
             await self._send_to_scheduler(obj)
-        except Exception as e:
-            # Clean up on failure
-            del self.rid_to_state[request_id]
-            raise RuntimeError(f"Failed to send request to scheduler: {e}")
 
-
+            is_stream = getattr(obj, "stream", False)
+
+            while True:
+                # Client cancelled - notify scheduler and exit
+                if grpc_context and grpc_context.cancelled():
+                    await self.abort_request(request_id)
+                    return
+
+                try:
+                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                    if is_stream:
+                        yield response
+
+                    # Non-streaming: yield final response with accumulated tokens from state
+                    if isinstance(response, dict) and response.get("finished", False):
+                        if not is_stream:
+                            final_response = response.copy()
+                            final_response["token_ids"] = state.output_ids
+                            yield final_response
+                        break
+
+                except asyncio.TimeoutError:
+                    # Timeout waiting for response - abort and cleanup
+                    logger.warning(
+                        f"Timeout waiting for response for request {request_id}"
+                    )
+                    await self.abort_request(request_id)
+                    return
+
+        finally:
+            # Always clean up request state when exiting
+            self._cleanup_request_state(request_id)
+
+    def _cleanup_request_state(self, request_id: str):
+        """Clean up local request state (does not notify scheduler)."""
+        if request_id in self.rid_to_state:
+            del self.rid_to_state[request_id]
 
     async def embedding_request(
         self,
```
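In `_handle_single_request` above, the response loop never blocks indefinitely: it waits on `state.out_queue` in 4-second slices, checks `grpc_context.cancelled()` between waits, and treats a full timeout with no output as a stalled request to abort. A compressed sketch of that shape; everything except `asyncio.wait_for` is an invented stand-in:

```python
import asyncio


async def noop():
    pass


async def consume(queue: asyncio.Queue, is_cancelled, abort):
    # Shape of the receive loop above: re-check an out-of-band cancellation
    # signal between bounded waits; 4s of silence aborts the request.
    while True:
        if is_cancelled():  # e.g. grpc_context.cancelled()
            await abort()
            return
        try:
            response = await asyncio.wait_for(queue.get(), timeout=4)
        except asyncio.TimeoutError:
            await abort()
            return
        yield response
        if isinstance(response, dict) and response.get("finished", False):
            return


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    queue.put_nowait({"token_ids": [1, 2], "finished": True})
    async for response in consume(queue, is_cancelled=lambda: False, abort=noop):
        print(response)


asyncio.run(main())
```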
```diff
@@ -214,9 +365,7 @@ class GrpcRequestManager:
         """
         # Generate request ID if not provided
         if request_id is None:
-
-            request_id = f"grpc-embed-{self.request_counter}"
-            self.request_counter += 1
+            request_id = f"grpc-embed-{uuid.uuid4().hex}"
 
         obj.rid = request_id
 
@@ -318,9 +467,9 @@ class GrpcRequestManager:
                     await self.is_pause_cond.wait()
 
                 # Handle different output types
-                if isinstance(recv_obj, BatchTokenIDOut):
+                if isinstance(recv_obj, BatchTokenIDOutput):
                     await self._handle_batch_output(recv_obj)
-                elif isinstance(recv_obj, BatchEmbeddingOut):
+                elif isinstance(recv_obj, BatchEmbeddingOutput):
                     await self._handle_embedding_output(recv_obj)
                 elif isinstance(recv_obj, HealthCheckOutput):
                     await self._handle_health_check_output(recv_obj)
@@ -332,12 +481,71 @@ class GrpcRequestManager:
                 if self.gracefully_exit:
                     break
                 continue
+            except zmq.error.ZMQError as e:
+                # Socket closed or other ZMQ error - exit cleanly if shutting down
+                if self.gracefully_exit:
+                    logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
+                    break
+                logger.error(
+                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
+                )
+                break
             except Exception as e:
                 logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                 if self.gracefully_exit:
                     break
 
-    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+    def _convert_logprob_style(
+        self,
+        state: GrpcReqState,
+        batch_out: BatchTokenIDOutput,
+        batch_index: int,
+    ):
+        """
+        Convert and accumulate logprobs from batch output to state.
+        Follows the same logic as tokenizer_manager.convert_logprob_style.
+        """
+        # Early exit if no input logprobs at all
+        if batch_out.input_token_logprobs_val is None:
+            return
+
+        # Accumulate input token logprobs (only if list is non-empty)
+        if len(batch_out.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                batch_out.input_token_logprobs_val[batch_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                batch_out.input_token_logprobs_idx[batch_index]
+            )
+
+        # Always accumulate output token logprobs
+        state.output_token_logprobs_val.extend(
+            batch_out.output_token_logprobs_val[batch_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            batch_out.output_token_logprobs_idx[batch_index]
+        )
+
+        # Handle top logprobs if requested
+        if state.obj.top_logprobs_num > 0:
+            # Accumulate input top logprobs (only if list is non-empty)
+            if len(batch_out.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    batch_out.input_top_logprobs_val[batch_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    batch_out.input_top_logprobs_idx[batch_index]
+                )
+
+            # Always accumulate output top logprobs
+            state.output_top_logprobs_val.extend(
+                batch_out.output_top_logprobs_val[batch_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                batch_out.output_top_logprobs_idx[batch_index]
+            )
+
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         """Handle batch generation output from scheduler."""
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):
@@ -355,7 +563,6 @@ class GrpcRequestManager:
             # Extract output for this request
             output_data = {
                 "request_id": rid,
-                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                 "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                 "finished": batch_out.finished_reasons[i] is not None,
                 "meta_info": {
@@ -367,37 +574,81 @@ class GrpcRequestManager:
                     if batch_out.completion_tokens
                     else 0
                 ),
+                "cached_tokens": (
+                    batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                ),
                 "finish_reason": (
-
+                    batch_out.finished_reasons[i]
                     if batch_out.finished_reasons[i]
                     else None
                 ),
             },
         }
 
-            #
-            if
-                batch_out
-            ):
-                output_data["logprobs"] = {
-                    "tokens": batch_out.output_token_logprobs_val[i],
-                    "top_logprobs": (
-                        batch_out.output_top_logprobs_val[i]
-                        if batch_out.output_top_logprobs_val
-                        and i < len(batch_out.output_top_logprobs_val)
-                        else None
-                    ),
-                }
-
-            # Update state
-            if output_data["text"]:
-                state.text += output_data["text"][state.last_output_offset :]
-                state.last_output_offset = len(output_data["text"])
+            # Accumulate logprobs (following tokenizer_manager pattern)
+            if state.obj.return_logprob:
+                self._convert_logprob_style(state, batch_out, i)
 
+            # Send input logprobs based if available
+            if (
+                state.obj.return_logprob
+                and state.obj.logprob_start_len >= 0
+                and state.input_token_logprobs_val
+            ):
+                if state.obj.stream and not state.input_logprobs_sent:
+                    # Streaming: send input logprobs once in first chunk that has them
+                    output_data["input_logprobs"] = {
+                        "token_logprobs_val": state.input_token_logprobs_val,
+                        "token_logprobs_idx": state.input_token_logprobs_idx,
+                        "top_logprobs_val": state.input_top_logprobs_val,
+                        "top_logprobs_idx": state.input_top_logprobs_idx,
+                    }
+                    state.input_logprobs_sent = True
+                elif not state.obj.stream and output_data["finished"]:
+                    # Non-streaming: send input logprobs in final chunk
+                    output_data["input_logprobs"] = {
+                        "token_logprobs_val": state.input_token_logprobs_val,
+                        "token_logprobs_idx": state.input_token_logprobs_idx,
+                        "top_logprobs_val": state.input_top_logprobs_val,
+                        "top_logprobs_idx": state.input_top_logprobs_idx,
+                    }
+
+            # Send output logprobs if available
+            if (
+                state.obj.return_logprob
+                and batch_out.output_token_logprobs_val
+                and i < len(batch_out.output_token_logprobs_val)
+            ):
+                if state.obj.stream:
+                    # For streaming: send incremental logprobs (only new tokens in this chunk)
+                    # NOTE: this is different than TokenizerManager, which always accumulates
+                    def get_part(attr_name):
+                        source_list = getattr(batch_out, attr_name, None)
+                        return (
+                            source_list[i]
+                            if source_list and i < len(source_list)
+                            else []
+                        )
+
+                    output_data["output_logprobs"] = {
+                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
+                        "token_logprobs_idx": get_part("output_token_logprobs_idx"),
+                        "top_logprobs_val": get_part("output_top_logprobs_val"),
+                        "top_logprobs_idx": get_part("output_top_logprobs_idx"),
+                    }
+                elif output_data["finished"]:
+                    # Non-streaming: send cumulative output logprobs in final chunk
+                    output_data["output_logprobs"] = {
+                        "token_logprobs_val": state.output_token_logprobs_val,
+                        "token_logprobs_idx": state.output_token_logprobs_idx,
+                        "top_logprobs_val": state.output_top_logprobs_val,
+                        "top_logprobs_idx": state.output_top_logprobs_idx,
+                    }
+
+            # Update state for accumulation
             if output_data["token_ids"]:
                 state.output_ids.extend(output_data["token_ids"])
 
-            # Send to output queue
             await state.out_queue.put(output_data)
 
             # Handle completion
@@ -415,7 +666,7 @@ class GrpcRequestManager:
 
             asyncio.create_task(cleanup())
 
-    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         """Handle batch embedding output from scheduler."""
         for i, rid in enumerate(batch_out.rids):
             if rid not in self.rid_to_state:
```
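A behavioral note on the logprob hunks above: in streaming mode each chunk carries only the logprobs for its new tokens (the `NOTE` in the diff flags that this differs from TokenizerManager, which accumulates), while the non-streaming final response carries the lists accumulated in `GrpcReqState`. A toy illustration with invented values:

```python
# Two chunks of output-token logprobs for one request (values invented).
chunk1 = [-0.11, -0.32]
chunk2 = [-0.05]

# Streaming: each response carries only its chunk's new values.
stream_payloads = [chunk1, chunk2]

# Non-streaming: the final response carries the state-accumulated list.
accumulated = chunk1 + chunk2  # [-0.11, -0.32, -0.05]

assert accumulated == [v for chunk in stream_payloads for v in chunk]
```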
```diff
@@ -499,8 +750,17 @@ class GrpcRequestManager:
         logger.info("Shutting down GrpcRequestManager")
         self.gracefully_exit = True
 
+        # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
+        for task in list(self.asyncio_tasks):
+            if not task.done():
+                task.cancel()
+
+        # Give tasks a moment to process cancellation
+        if self.asyncio_tasks:
+            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
         # Cancel all pending requests
-        for rid, state in self.rid_to_state.items():
+        for rid, state in list(self.rid_to_state.items()):
             if not state.finished:
                 await state.out_queue.put(
                     {"error": "Server shutting down", "shutdown": True}
@@ -512,10 +772,25 @@ class GrpcRequestManager:
         if self.asyncio_tasks:
             await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
 
+        # Shutdown bootstrap server if running
+        if self.bootstrap_server:
+            logger.info("Shutting down bootstrap server")
+            try:
+                if hasattr(self.bootstrap_server, "shutdown"):
+                    if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
+                        await self.bootstrap_server.shutdown()
+                    else:
+                        self.bootstrap_server.shutdown()
+            except Exception as e:
+                logger.warning(f"Error shutting down bootstrap server: {e}")
+
         # Close ZMQ sockets
         self.recv_from_scheduler.close()
         self.send_to_scheduler.close()
 
+        # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
+        self.context.term()
+
         logger.info("GrpcRequestManager shutdown complete")
 
     def get_server_info(self) -> Dict[str, Any]:
```
|