sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/grpc_server.py (in the hunks below, `…` marks removed-line content that the diff view did not capture):

```diff
@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
 import grpc
 from grpc_reflection.v1alpha import reflection
 
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
 from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
 from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
```
```diff
@@ -36,6 +38,20 @@ logger = logging.getLogger(__name__)
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 
 
+def _run_scheduler_with_signal_handling(*args, **kwargs):
+    """
+    Wrapper for run_scheduler_process that ignores SIGINT.
+
+    The scheduler process should not handle Ctrl+C - it should only terminate
+    when the parent gRPC server exits (via kill_itself_when_parent_died).
+    """
+    # Ignore SIGINT in this subprocess - let the parent handle it
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+    # Now run the actual scheduler process
+    run_scheduler_process(*args, **kwargs)
+
+
 def _launch_scheduler_process_only(
     server_args: ServerArgs,
     port_args: Optional[PortArgs] = None,
```
```diff
@@ -88,7 +104,7 @@ def _launch_scheduler_process_only(
         )
         moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
         proc = mp.Process(
-            target=run_scheduler_process,
+            target=_run_scheduler_with_signal_handling,
             args=(
                 server_args,
                 port_args,
@@ -98,7 +114,6 @@ def _launch_scheduler_process_only(
                 pp_rank,
                 None,
                 writer,
-                None,
             ),
         )
 
```
```diff
@@ -181,20 +196,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)
 
-            # Submit to request manager
-            …
+            # Submit to request manager (automatically handles n>1)
+            response_generator = self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )
 
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+            async for output in response_generator:
+                # Handle batch responses (for n>1 non-streaming)
+                if isinstance(output, list):
+                    for batch_output in output:
+                        if "error" in batch_output:
+                            yield sglang_scheduler_pb2.GenerateResponse(
+                                request_id=request.request_id,
+                                error=sglang_scheduler_pb2.GenerateError(
+                                    message=batch_output["error"],
+                                    http_status_code=(
+                                        "500" if "abort" not in batch_output else "499"
+                                    ),
+                                ),
+                            )
+                        else:
+                            # All non-error batch outputs are final responses
+                            yield self._create_completion_response(
+                                request.request_id, batch_output
+                            )
+                else:
+                    # Handle single response (for streaming or n=1 non-streaming)
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,
```
```diff
@@ -205,27 +234,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                             ),
                         ),
                     )
-                    …
-                    …
-                    # Check if finished
-                    if output.get("finished", False):
-                        # Send completion
+                    elif output.get("finished", False):
                         yield self._create_completion_response(
                             request.request_id, output
                         )
-                        break
                     else:
-                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)
 
-            except asyncio.TimeoutError:
-                # Check if context is still active
-                if context.cancelled():
-                    # Abort the request
-                    await self.request_manager.abort_request(request.request_id)
-                    break
-                continue
-
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(
```
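The rewritten loop relies on the request manager yielding two shapes: a plain dict per output, or a list of final dicts for an `n>1` batch. A toy illustration of that dispatch, with a hypothetical generator standing in for `generate_request`:

```python
import asyncio
from typing import AsyncIterator, Dict, List, Union


async def fake_outputs() -> AsyncIterator[Union[Dict, List[Dict]]]:
    # Hypothetical stand-in for request_manager.generate_request():
    # plain dicts are per-request outputs, a list is a finished n>1 batch.
    yield {"token_ids": [1, 2], "finished": False}
    yield [
        {"token_ids": [1, 2, 3], "finished": True},
        {"error": "aborted", "abort": True},
    ]


async def consume() -> None:
    async for output in fake_outputs():
        if isinstance(output, list):
            # n>1 batch: every element is final (error or completion)
            for item in output:
                print("batch item:", "error" if "error" in item else "complete")
        elif output.get("finished", False):
            print("complete:", output["token_ids"])
        else:
            print("chunk:", output["token_ids"])


asyncio.run(consume())
```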
```diff
@@ -266,7 +281,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                     prompt_tokens=result.get("prompt_tokens", 0),
                     cached_tokens=0,
                     embedding_dim=len(result["embedding"]),
-                    generation_time=time.time() - self.start_time,
                 ),
             )
 
```
```diff
@@ -319,17 +333,21 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             token_ids_logprob=None,
         )
 
+        if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
+            health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
+            health_request.bootstrap_room = 0
+
         logger.info(f"Sending health check request to request manager...")
 
         # Submit and wait for response
-        …
+        output_generator = self.request_manager.generate_request(
             health_request, request_id=rid
         )
 
         try:
-            # …
+            # Get first response with timeout
             response = await asyncio.wait_for(
-                …
+                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
             )
 
             # Clean up
```
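The health check bounds only the first item from the generator, by applying `asyncio.wait_for` to `__anext__()`. A self-contained sketch of that timeout pattern (hypothetical generator and status field):

```python
import asyncio


async def slow_gen():
    await asyncio.sleep(10)  # simulates a scheduler that never answers
    yield {"status": "ok"}


async def health_check(timeout: float = 2.0) -> bool:
    gen = slow_gen()
    try:
        # Same shape as the hunk above: bound only the *first* item.
        first = await asyncio.wait_for(gen.__anext__(), timeout=timeout)
        return first.get("status") == "ok"
    except asyncio.TimeoutError:
        return False
    finally:
        await gen.aclose()  # don't leak the generator


print(asyncio.run(health_check()))  # prints False after ~2s
```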
```diff
@@ -394,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         # Convert sampling params
         sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
 
+        # Extract disaggregated params if present
+        bootstrap_host = None
+        bootstrap_port = None
+        bootstrap_room = None
+        if grpc_req.HasField("disaggregated_params"):
+            bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
+            bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
+            bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
+
         # Create request
         return TokenizedGenerateReqInput(
             rid=grpc_req.request_id,
```
```diff
@@ -402,13 +429,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             mm_inputs=None,  # TODO: implement mm support
             sampling_params=sampling_params,
             return_logprob=grpc_req.return_logprob,
-            logprob_start_len=…
+            logprob_start_len=(
+                grpc_req.logprob_start_len
+                if grpc_req.logprob_start_len is not None
+                else -1
+            ),
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=…
-            …
+            stream=grpc_req.stream or False,
+            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
             ),
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
 
     def _convert_embed_request(
```
```diff
@@ -438,6 +472,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         regex = None
         json_schema = None
         ebnf_grammar = None
+        structural_tag = None
 
         if grpc_params.HasField("regex"):
             regex = grpc_params.regex
@@ -445,6 +480,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             json_schema = grpc_params.json_schema
         elif grpc_params.HasField("ebnf_grammar"):
             ebnf_grammar = grpc_params.ebnf_grammar
+        elif grpc_params.HasField("structural_tag"):
+            structural_tag = grpc_params.structural_tag
 
         return SGLSamplingParams(
             temperature=grpc_params.temperature or 1.0,
```
```diff
@@ -456,33 +493,114 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             repetition_penalty=grpc_params.repetition_penalty or 1.0,
             max_new_tokens=grpc_params.max_new_tokens or 128,
             min_new_tokens=grpc_params.min_new_tokens or 0,
-            stop=list(grpc_params.stop) if grpc_params.stop else …
+            stop=list(grpc_params.stop) if grpc_params.stop else [],
             stop_token_ids=(
-                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else …
+                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
             ),
             skip_special_tokens=grpc_params.skip_special_tokens,
             spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
             regex=regex,
             json_schema=json_schema,
             ebnf=ebnf_grammar,
+            structural_tag=structural_tag,
             n=grpc_params.n or 1,
             ignore_eos=grpc_params.ignore_eos,
         )
 
+    def _convert_output_logprobs_to_proto(
+        self, logprobs_data: Dict
+    ) -> Optional[sglang_scheduler_pb2.OutputLogProbs]:
+        """Convert output logprobs dict to proto (no None values, plain floats)."""
+        if not logprobs_data:
+            return None
+
+        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
+        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
+        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
+        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
+
+        # Build TopLogProbs entries
+        top_logprobs_proto = []
+        if top_logprobs_val and top_logprobs_idx:
+            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
+                top_logprobs_proto.append(
+                    sglang_scheduler_pb2.TopLogProbs(
+                        values=val_list,
+                        token_ids=idx_list,
+                    )
+                )
+
+        return sglang_scheduler_pb2.OutputLogProbs(
+            token_logprobs=token_logprobs_val,  # Plain float array
+            token_ids=token_logprobs_idx,
+            top_logprobs=top_logprobs_proto,
+        )
+
+    def _convert_input_logprobs_to_proto(
+        self, logprobs_data: Dict
+    ) -> Optional[sglang_scheduler_pb2.InputLogProbs]:
+        """Convert input logprobs dict to proto (first token is None, wrapped in InputTokenLogProb)."""
+        if not logprobs_data:
+            return None
+
+        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
+        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
+        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
+        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
+
+        # Wrap values in InputTokenLogProb (None for first token, value for others)
+        token_logprobs_wrapped = [
+            (
+                sglang_scheduler_pb2.InputTokenLogProb()
+                if x is None
+                else sglang_scheduler_pb2.InputTokenLogProb(value=x)
+            )
+            for x in token_logprobs_val
+        ]
+
+        # Build TopLogProbs entries
+        top_logprobs_proto = []
+        if top_logprobs_val and top_logprobs_idx:
+            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
+                top_logprobs_proto.append(
+                    sglang_scheduler_pb2.TopLogProbs(
+                        values=val_list,
+                        token_ids=idx_list,
+                    )
+                )
+
+        return sglang_scheduler_pb2.InputLogProbs(
+            token_logprobs=token_logprobs_wrapped,
+            token_ids=token_logprobs_idx,
+            top_logprobs=top_logprobs_proto,
+        )
+
     def _create_chunk_response(
         self, request_id: str, output: Dict
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a streaming chunk response."""
+        meta_info = output.get("meta_info", {})
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present (only in first chunk)
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
+
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                …
-                …
-                …
-                …
-                …
-                …
-                …
+                token_ids=output.get("token_ids", []),
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get("completion_tokens", 0),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
             ),
         )
 
```
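`_convert_input_logprobs_to_proto` keeps the first prompt token's missing logprob by wrapping every value in an `InputTokenLogProb` message instead of using a bare repeated float, so positions stay aligned with `token_ids`. A sketch of the idea, with a plain dataclass standing in for the proto wrapper:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TokenLogProb:
    # Stand-in for the InputTokenLogProb wrapper message: an unset `value`
    # plays the role of proto3's "field not present".
    value: Optional[float] = None


def wrap(vals: List[Optional[float]]) -> List[TokenLogProb]:
    # The first prompt token has no logprob (None); wrap it instead of
    # dropping it so indices still line up with token_ids.
    return [TokenLogProb() if v is None else TokenLogProb(value=v) for v in vals]


print(wrap([None, -0.12, -1.5]))
# [TokenLogProb(value=None), TokenLogProb(value=-0.12), TokenLogProb(value=-1.5)]
```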
```diff
@@ -491,20 +609,57 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""
 
-        # …
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        # Extract meta info and finish reason details
         meta_info = output.get("meta_info", {})
-        …
-        …
-        …
-        …
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
 
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             complete=sglang_scheduler_pb2.GenerateComplete(
                 output_ids=output.get("token_ids", []),
-                output_text=output.get("text", ""),
                 finish_reason=finish_reason,
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get(
+                    "completion_tokens", len(output.get("token_ids", []))
+                ),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
+                **matched_stop_kwargs,
             ),
         )
 
```
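The completion path normalizes `finish_reason`, which may arrive as a dict (`{"type": ..., "matched": ...}`) or as a legacy bare string. A compact sketch of that normalization, independent of the proto types:

```python
from typing import Any, Dict, Tuple


def normalize_finish(finish_reason_data: Any) -> Tuple[str, Dict[str, Any]]:
    # Mirrors the hunk above: accepts either the dict form
    # {"type": ..., "matched": ...} or a bare legacy string.
    finish_reason = "stop"
    matched_stop_kwargs: Dict[str, Any] = {}
    if finish_reason_data:
        if isinstance(finish_reason_data, dict):
            finish_reason_type = finish_reason_data.get("type")
        else:
            finish_reason_type = finish_reason_data
        if finish_reason_type in ("length", "abort"):
            finish_reason = finish_reason_type
    if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
        matched = finish_reason_data["matched"]
        if isinstance(matched, int):
            matched_stop_kwargs["matched_token_id"] = matched
        elif isinstance(matched, str):
            matched_stop_kwargs["matched_stop_str"] = matched
    return finish_reason, matched_stop_kwargs


print(normalize_finish({"type": "stop", "matched": 128009}))
# -> ('stop', {'matched_token_id': 128009})
```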
```diff
@@ -522,6 +677,16 @@ async def serve_grpc(
 ):
     """Start the standalone gRPC server with integrated scheduler."""
 
+    # Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
+    # This ensures the bootstrap server is ready when prefill schedulers try to register
+    bootstrap_server = None
+    if server_args.disaggregation_mode == "prefill":
+        bootstrap_server = start_disagg_service(server_args)
+        if bootstrap_server:
+            logger.info(
+                f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
+            )
+
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
     logger.info("Launching scheduler process(es)...")
     scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
```
```diff
@@ -545,9 +710,11 @@ async def serve_grpc(
     }
 
     # Create request manager with the correct port args
+    # Note: We pass None for bootstrap_server since it's already started above
     request_manager = GrpcRequestManager(
         server_args=server_args,
         port_args=port_args,
+        bootstrap_server=bootstrap_server,
     )
 
     # Create gRPC server
```
```diff
@@ -597,19 +764,28 @@ async def serve_grpc(
         await stop_event.wait()
     finally:
         logger.info("Shutting down gRPC server")
+
+        # Shutdown request manager first - this closes ZMQ sockets and stops background tasks
         await servicer.shutdown()
+
+        # Stop the gRPC server
         await server.stop(5.0)
 
-        # Terminate scheduler processes
+        # Terminate scheduler processes before exiting to avoid atexit hang
+        # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
         for i, proc in enumerate(scheduler_procs):
-            if proc…
+            if proc.is_alive():
                 logger.info(f"Terminating scheduler process {i}...")
                 proc.terminate()
-                proc.join(timeout=…)
+                proc.join(timeout=2.0)
                 if proc.is_alive():
-                    logger.warning(…
+                    logger.warning(
+                        f"Scheduler process {i} did not terminate, killing..."
+                    )
                     proc.kill()
-                    proc.join()
+                    proc.join(timeout=1.0)
+
+        logger.info("All scheduler processes terminated")
 
 
 def main():
```
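The shutdown path escalates per process: `terminate()`, bounded `join()`, then `kill()` with another bounded `join()`, so a stuck scheduler can never hang exit. The same escalation as a standalone sketch:

```python
import multiprocessing as mp
import time


def _worker():
    while True:
        time.sleep(0.5)


def stop_children(procs, term_timeout=2.0, kill_timeout=1.0):
    # Escalation used in the shutdown path above: SIGTERM, bounded join,
    # then SIGKILL with a second bounded join.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()
            proc.join(timeout=term_timeout)
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=kill_timeout)


if __name__ == "__main__":
    procs = [mp.Process(target=_worker) for _ in range(2)]
    for p in procs:
        p.start()
    stop_children(procs)
    print([p.exitcode for p in procs])  # negative exit codes: killed by signal
```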
```diff
@@ -618,55 +794,9 @@ def main():
     mp.set_start_method("spawn", force=True)
 
     parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
-
-    # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
-
-    # Model arguments
-    parser.add_argument("--model-path", type=str, required=True, help="Model path")
-    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
-    parser.add_argument("--context-length", type=int, help="Context length")
-    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
-
-    # Runtime arguments
-    parser.add_argument(
-        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
-    )
-    parser.add_argument(
-        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
-    )
-    parser.add_argument(
-        "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
-    )
-    parser.add_argument(
-        "--attention-backend", type=str, default="flashinfer", help="Attention backend"
-    )
-    parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
-
-    # Logging
-    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
-
+    ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
-
-    # Convert to ServerArgs with gRPC host and port
-    server_args = ServerArgs(
-        model_path=args.model_path,
-        tokenizer_path=args.tokenizer_path or args.model_path,
-        context_length=args.context_length,
-        tp_size=args.tp_size,
-        dp_size=args.dp_size,
-        max_running_requests=args.max_running_requests,
-        max_total_tokens=args.max_total_tokens,
-        max_prefill_tokens=args.max_prefill_tokens,
-        attention_backend=args.attention_backend,
-        lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
-        log_level=args.log_level,
-        # Override with gRPC server host and port
-        host=args.host,
-        port=args.port,
-    )
+    server_args = ServerArgs.from_cli_args(args)
 
     # Run server
     asyncio.run(
```
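`main()` now defers the whole flag surface to `ServerArgs.add_cli_args` / `ServerArgs.from_cli_args` instead of a hand-rolled subset. Rough usage, assuming only that these two helpers keep the signatures shown in the hunk:

```python
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
ServerArgs.add_cli_args(parser)  # registers the full engine flag set

# The same flags the old hand-rolled parser covered, plus everything else:
args = parser.parse_args(
    ["--model-path", "meta-llama/Llama-3.1-8B-Instruct", "--tp-size", "2"]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.model_path, server_args.tp_size)
```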
sglang/srt/entrypoints/http_server.py:

```diff
@@ -29,8 +29,6 @@ import time
 from http import HTTPStatus
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
 
-import setproctitle
-
 from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
 
 # Fix a bug of Python threading
```
```diff
@@ -72,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -95,8 +94,8 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
-    MultiTokenizerManager,
     MultiTokenizerRouter,
+    TokenizerWorker,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
```
```diff
@@ -128,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: Union[
-        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
-    ]
+    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
     template_manager: TemplateManager
     scheduler_info: Dict
 
```
```diff
@@ -165,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )
 
     # Launch multi-tokenizer manager process
-    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerWorker(server_args, port_args)
     template_manager = TemplateManager()
     template_manager.initialize_templates(
         tokenizer_manager=tokenizer_manager,
```
```diff
@@ -302,7 +299,23 @@ app.add_middleware(
 
 @app.exception_handler(HTTPException)
 async def validation_exception_handler(request: Request, exc: HTTPException):
-    """Enrich HTTP exception with status code and other details"""
+    """Enrich HTTP exception with status code and other details.
+
+    For /v1/responses, emit OpenAI-style nested error envelope:
+    {"error": {"message": "...", "type": "...", "param": null, "code": <status>}}
+    """
+    # adjust fmt for responses api
+    if request.url.path.startswith("/v1/responses"):
+        nested_error = {
+            "message": exc.detail,
+            "type": HTTPStatus(exc.status_code).phrase,
+            "param": None,
+            "code": exc.status_code,
+        }
+        return ORJSONResponse(
+            content={"error": nested_error}, status_code=exc.status_code
+        )
+
     error = ErrorResponse(
         object="error",
         message=exc.detail,
```
```diff
@@ -315,7 +328,10 @@ async def validation_exception_handler(request: Request, exc: HTTPException):
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
-    """Override FastAPI's default 422 validation error with 400"""
+    """Override FastAPI's default 422 validation error with 400.
+
+    For /v1/responses, emit OpenAI-style nested error envelope; for other endpoints keep legacy format.
+    """
     exc_str = str(exc)
     errors_str = str(exc.errors())
 
```
```diff
@@ -324,6 +340,16 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
     else:
         message = exc_str
 
+    if request.url.path.startswith("/v1/responses"):
+        # adapt specially, for v1/responses API only (notice the error key is different)
+        nested_error = {
+            "message": message,
+            "type": HTTPStatus.BAD_REQUEST.phrase,
+            "param": None,
+            "code": HTTPStatus.BAD_REQUEST.value,
+        }
+        return ORJSONResponse(status_code=400, content={"error": nested_error})
+
     err = ErrorResponse(
         message=message,
         type=HTTPStatus.BAD_REQUEST.phrase,
```
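Both handlers emit the same OpenAI-style envelope for `/v1/responses`, while other endpoints keep the legacy flat `ErrorResponse`. The payload shape, reproduced as a small helper:

```python
from http import HTTPStatus


def nested_error_envelope(message: str, status: int) -> dict:
    # The /v1/responses error payload shape from the two handlers above.
    return {
        "error": {
            "message": message,
            "type": HTTPStatus(status).phrase,
            "param": None,
            "code": status,
        }
    }


print(nested_error_envelope("missing field: input", 400))
# {'error': {'message': 'missing field: input', 'type': 'Bad Request',
#            'param': None, 'code': 400}}
```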
```diff
@@ -731,6 +757,20 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(
+    obj: DestroyWeightsUpdateGroupReqInput, request: Request
+):
+    """Destroy the parameter update group."""
+    success, message = (
+        await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_tensor")
 async def update_weights_from_tensor(
     obj: UpdateWeightsFromTensorReqInput, request: Request
```