sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py
CHANGED
@@ -15,7 +15,119 @@
 
 import time
 from dataclasses import dataclass
-from
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.utils import get_bool_env_var
+
+SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
+
+
+@dataclass
+class TimeStats:
+    """
+    Store the timestamps for each stage of a request.
+
+    Unified: wait_queue -> forward -> completion
+    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
+    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
+    """
+
+    lb_entry_time: float = 0.0
+    wait_queue_entry_time: float = 0.0
+    forward_entry_time: float = 0.0
+    completion_time: float = 0.0
+    prefill_bootstrap_queue_entry_time: float = 0.0
+    prefill_transfer_queue_entry_time: float = 0.0
+    decode_prealloc_queue_entry_time: float = 0.0
+    decode_transfer_queue_entry_time: float = 0.0
+
+    class RequestType(Enum):
+        UNIFIED = "unified"
+        PREFILL = "prefill"
+        DECODE = "decode"
+        INVALID = "invalid"
+
+    def __str__(self) -> str:
+        # if unified
+        _type = self.get_type()
+
+        if _type == self.RequestType.UNIFIED:
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    queue_duration >= 0 and forward_duration >= 0
+                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
+        elif _type == self.RequestType.PREFILL:
+            bootstrap_duration = (
+                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+            )
+
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    bootstrap_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
+        # if decode
+        elif _type == self.RequestType.DECODE:
+            prealloc_duration = (
+                self.decode_transfer_queue_entry_time
+                - self.decode_prealloc_queue_entry_time
+            )
+
+            transfer_duration = (
+                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    prealloc_duration >= 0
+                    and transfer_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
+        else:
+            return "Invalid Time Stats"
+
+    def format_duration(self, duration: float) -> str:
+        return f"{duration * 1e3:.2f}ms"
+
+    def get_type(self) -> RequestType:
+        """Determine the type of request based on timestamp values."""
+        if (
+            self.prefill_bootstrap_queue_entry_time == 0.0
+            and self.prefill_transfer_queue_entry_time == 0.0
+            and self.decode_prealloc_queue_entry_time == 0.0
+            and self.decode_transfer_queue_entry_time == 0.0
+        ):
+            return self.RequestType.UNIFIED
+        elif (
+            self.prefill_bootstrap_queue_entry_time > 0.0
+            and self.prefill_transfer_queue_entry_time > 0.0
+        ):
+            return self.RequestType.PREFILL
+        elif (
+            self.decode_prealloc_queue_entry_time > 0.0
+            and self.decode_transfer_queue_entry_time > 0.0
+            and self.wait_queue_entry_time > 0.0
+        ):
+            return self.RequestType.DECODE
+        else:
+            return self.RequestType.INVALID
 
 
 @dataclass
@@ -26,18 +138,23 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    num_grammar_queue_reqs: int = 0
     spec_accept_length: float = 0.0
     avg_request_queue_latency: float = 0.0
+    num_prefill_prealloc_queue_reqs: int = 0
+    num_prefill_infight_queue_reqs: int = 0
+    num_decode_prealloc_queue_reqs: int = 0
+    num_decode_transfer_queue_reqs: int = 0
 
 
 class SchedulerMetricsCollector:
 
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import
+        from prometheus_client import Counter, Gauge
 
         self.labels = labels
-        self.last_log_time = time.
+        self.last_log_time = time.perf_counter()
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )
 
+        self.num_grammar_queue_reqs = Gauge(
+            name="sglang:num_grammar_queue_reqs",
+            documentation="The number of requests in the grammar waiting queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
             documentation="The prefix cache hit rate.",
@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )
 
+        # Disaggregation queue metrics
+        self.num_prefill_prealloc_queue_reqs = Gauge(
+            name="sglang:num_prefill_prealloc_queue_reqs",
+            documentation="The number of requests in the prefill prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_prefill_infight_queue_reqs = Gauge(
+            name="sglang:num_prefill_infight_queue_reqs",
+            documentation="The number of requests in the prefill infight queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_decode_prealloc_queue_reqs = Gauge(
+            name="sglang:num_decode_prealloc_queue_reqs",
+            documentation="The number of requests in the decode prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_decode_transfer_queue_reqs = Gauge(
+            name="sglang:num_decode_transfer_queue_reqs",
+            documentation="The number of requests in the decode transfer queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_bootstrap_failed_reqs = Counter(
+            name="sglang:num_bootstrap_failed_reqs",
+            documentation="The number of bootstrap failed requests.",
+            labelnames=labels.keys(),
+        )
+
+        self.num_transfer_failed_reqs = Counter(
+            name="sglang:num_transfer_failed_reqs",
+            documentation="The number of transfer failed requests.",
+            labelnames=labels.keys(),
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)
 
+    def increment_bootstrap_failed_reqs(self) -> None:
+        self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
+
+    def increment_transfer_failed_reqs(self) -> None:
+        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
+
     def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
         self._log_gauge(self.token_usage, stats.token_usage)
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+        self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
-
-
+
+        # Disaggregation metrics
+        self._log_gauge(
+            self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
+        )
+        self._log_gauge(
+            self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
+        )
+        self._log_gauge(
+            self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
+        )
+        self._log_gauge(
+            self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
+        )
+
+        self.last_log_time = time.perf_counter()
 
 
 class TokenizerMetricsCollector:
-    def __init__(
+    def __init__(
+        self,
+        labels: Dict[str, str],
+        bucket_time_to_first_token: Optional[List[float]] = None,
+        bucket_inter_token_latency: Optional[List[float]] = None,
+        bucket_e2e_request_latency: Optional[List[float]] = None,
+        collect_tokens_histogram: bool = False,
+    ) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
         from prometheus_client import Counter, Histogram
 
         self.labels = labels
+        self.collect_tokens_histogram = collect_tokens_histogram
 
         self.prompt_tokens_total = Counter(
             name="sglang:prompt_tokens_total",
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        if collect_tokens_histogram:
+            bucket_prompt_tokens = [
+                100,
+                300,
+                500,
+                700,
+                1000,
+                1500,
+                2000,
+                3000,
+                4000,
+                5000,
+                6000,
+                7000,
+                8000,
+                9000,
+                10000,
+                12000,
+                15000,
+                20000,
+                22000,
+                25000,
+                30000,
+                35000,
+                40000,
+            ]
+            self.prompt_tokens_histogram = Histogram(
+                name="sglang:prompt_tokens_histogram",
+                documentation="Histogram of prompt token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_prompt_tokens,
+            )
+            bucket_generation_tokens = [
+                100,
+                300,
+                500,
+                1000,
+                1200,
+                1500,
+                1700,
+                2000,
+                2500,
+                3000,
+                3500,
+                4000,
+                4500,
+                5000,
+                6000,
+                7000,
+                8000,
+                9000,
+                10000,
+            ]
+            self.generation_tokens_histogram = Histogram(
+                name="sglang:generation_tokens_histogram",
+                documentation="Histogram of generation token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_generation_tokens,
+            )
+
         self.cached_tokens_total = Counter(
             name="sglang:cached_tokens_total",
             documentation="Number of cached prompt tokens.",
@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
-        self.
-            name="sglang:
-            documentation="
+        self.num_so_requests_total = Counter(
+            name="sglang:num_so_requests_total",
+            documentation="Number of structured output requests processed.",
             labelnames=labels.keys(),
-
+        )
+
+        if bucket_time_to_first_token is None:
+            bucket_time_to_first_token = [
                 0.1,
                 0.2,
                 0.4,
@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
                 100,
                 200,
                 400,
-            ]
-        )
+            ]
 
-
-
-
-
-
+        if bucket_e2e_request_latency is None:
+            bucket_e2e_request_latency = [
+                0.1,
+                0.2,
+                0.4,
+                0.6,
+                0.8,
+                1,
+                2,
+                4,
+                6,
+                8,
+                10,
+                20,
+                40,
+                60,
+                80,
+                100,
+                200,
+                400,
+                800,
+            ]
+
+        if bucket_inter_token_latency is None:
+            bucket_inter_token_latency = [
                 0.002,
                 0.004,
                 0.006,
@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
                 4.000,
                 6.000,
                 8.000,
-            ]
+            ]
+
+        self.histogram_time_to_first_token = Histogram(
+            name="sglang:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_time_to_first_token,
+        )
+
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_inter_token_latency,
         )
 
         self.histogram_e2e_request_latency = Histogram(
             name="sglang:e2e_request_latency_seconds",
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
-            buckets=
-                0.1,
-                0.2,
-                0.4,
-                0.6,
-                0.8,
-                1,
-                2,
-                4,
-                6,
-                8,
-                10,
-                20,
-                40,
-                60,
-                80,
-                100,
-                200,
-                400,
-                800,
-            ],
+            buckets=bucket_e2e_request_latency,
         )
 
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
+        has_grammar: bool,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         if cached_tokens > 0:
             self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        if has_grammar:
+            self.num_so_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if self.collect_tokens_histogram:
+            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
+            self._log_histogram(self.generation_tokens_histogram, generation_tokens)
 
     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
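The `TimeStats` class added above is the core of the new per-request timing instrumentation. The following sketch is not part of the diff; the timestamp values are illustrative. It shows how a request that only sets the unified-path timestamps is classified and formatted:

import time

from sglang.srt.metrics.collector import TimeStats

stats = TimeStats()
t0 = time.perf_counter()
stats.wait_queue_entry_time = t0      # request enters the waiting queue
stats.forward_entry_time = t0 + 0.05  # scheduler launches the forward pass
stats.completion_time = t0 + 1.25     # last token is produced

# None of the prefill/decode disaggregation timestamps were set,
# so get_type() classifies the request as UNIFIED.
assert stats.get_type() == TimeStats.RequestType.UNIFIED
print(stats)  # queue_duration=50.00ms, forward_duration=1200.00ms, start_time=...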
sglang/srt/mm_utils.py
CHANGED
@@ -36,6 +36,16 @@ from io import BytesIO
 
 import numpy as np
 from PIL import Image
 
+from sglang.srt.utils import flatten_nested_list
+
+
+def has_valid_data(data) -> bool:
+    if data is None:
+        return False
+    if isinstance(data, list):
+        return any(has_valid_data(item) for item in flatten_nested_list(data))
+    return True
+
 
 def select_best_resolution(original_size, possible_resolutions):
     """
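The new `has_valid_data` helper treats `None` and arbitrarily nested lists of `None` as empty. Below is a self-contained sketch of its behavior; the `flatten_nested_list` stand-in mirrors the helper imported from `sglang.srt.utils` and is inlined only so the snippet runs on its own:

def flatten_nested_list(nested):
    # Stand-in for sglang.srt.utils.flatten_nested_list: recursively
    # flattens nested lists into a single flat list of leaves.
    if isinstance(nested, list):
        return [leaf for item in nested for leaf in flatten_nested_list(item)]
    return [nested]

def has_valid_data(data) -> bool:
    if data is None:
        return False
    if isinstance(data, list):
        return any(has_valid_data(item) for item in flatten_nested_list(data))
    return True

print(has_valid_data(None))                 # False
print(has_valid_data([]))                   # False: any() over an empty list
print(has_valid_data([None, [None]]))       # False: only None leaves
print(has_valid_data([None, ["img.png"]]))  # True: one real item suffices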
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -19,7 +19,7 @@ import bisect
 import inspect
 import os
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
@@ -30,6 +30,7 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -40,14 +41,18 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
-    is_hip,
     rank0_log,
 )
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
 
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -137,7 +142,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
 
     gpu_mem = get_device_memory_capacity()
-    # Batch size of each rank will not become so large when DP is on
     if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
 
@@ -148,12 +152,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         model_runner.req_to_token_pool.size
     ]
 
-    capture_bs = list(sorted(set(capture_bs)))
-
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
-    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+        if max(capture_bs) < server_args.cuda_graph_max_bs:
+            capture_bs += list(
+                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+            )
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
@@ -211,7 +218,10 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-
+        if global_server_args_dict["attention_backend"] == "flashmla":
+            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        else:
+            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -237,6 +247,7 @@ class CudaGraphRunner:
         self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
 
         # pipeline parallelism
         if self.pp_size > 1:
@@ -296,28 +307,23 @@ class CudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture
+                f"Capture CUDA graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable
+                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
     @contextmanager
     def model_capture_mode(self):
-
-
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-
-        self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -400,6 +406,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism
        if self.pp_size > 1:
@@ -458,6 +465,7 @@ class CudaGraphRunner:
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
+            num_token_non_padded=self.num_token_non_padded,
         )
 
         if lora_paths is not None:
@@ -553,6 +561,7 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        self.num_token_non_padded[...] = len(forward_batch.input_ids)
        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
@@ -605,6 +614,7 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
+
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
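The reworked `get_batch_sizes_to_capture` logic above now pads the capture list in steps of 16 toward `--cuda-graph-max-bs` before clipping, deduplicating, and sorting. A standalone sketch of that flow follows; the seed list, max-bs limit, and pool size are all hypothetical values made up for illustration:

capture_bs = list(range(1, 33)) + [48, 64, 96, 128, 160]  # hypothetical seed list
cuda_graph_max_bs = 256       # stand-in for server_args.cuda_graph_max_bs
req_to_token_pool_size = 512  # stand-in for model_runner.req_to_token_pool.size

capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
if max(capture_bs) < cuda_graph_max_bs:
    # New in post5: extend the list toward the limit in steps of 16.
    capture_bs += list(range(max(capture_bs), cuda_graph_max_bs + 1, 16))
capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
capture_bs = sorted(set(capture_bs))
assert len(capture_bs) > 0 and capture_bs[0] > 0
print(capture_bs[-5:])  # [192, 208, 224, 240, 256]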
|