sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py CHANGED
@@ -15,7 +15,119 @@
 
  import time
  from dataclasses import dataclass
- from typing import Dict, Union
+ from enum import Enum
+ from typing import Dict, List, Optional, Union
+
+ from sglang.srt.utils import get_bool_env_var
+
+ SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
+
+
+ @dataclass
+ class TimeStats:
+     """
+     Store the timestamps for each stage of a request.
+
+     Unified: wait_queue -> forward -> completion
+     Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
+     Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
+     """
+
+     lb_entry_time: float = 0.0
+     wait_queue_entry_time: float = 0.0
+     forward_entry_time: float = 0.0
+     completion_time: float = 0.0
+     prefill_bootstrap_queue_entry_time: float = 0.0
+     prefill_transfer_queue_entry_time: float = 0.0
+     decode_prealloc_queue_entry_time: float = 0.0
+     decode_transfer_queue_entry_time: float = 0.0
+
+     class RequestType(Enum):
+         UNIFIED = "unified"
+         PREFILL = "prefill"
+         DECODE = "decode"
+         INVALID = "invalid"
+
+     def __str__(self) -> str:
+         # if unified
+         _type = self.get_type()
+
+         if _type == self.RequestType.UNIFIED:
+             queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+             forward_duration = self.completion_time - self.forward_entry_time
+
+             if SGLANG_TEST_REQUEST_TIME_STATS:
+                 assert (
+                     queue_duration >= 0 and forward_duration >= 0
+                 ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+             return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
+         elif _type == self.RequestType.PREFILL:
+             bootstrap_duration = (
+                 self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+             )
+
+             queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+
+             forward_duration = self.completion_time - self.forward_entry_time
+
+             if SGLANG_TEST_REQUEST_TIME_STATS:
+                 assert (
+                     bootstrap_duration >= 0
+                     and queue_duration >= 0
+                     and forward_duration >= 0
+                 ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+             return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
+         # if decode
+         elif _type == self.RequestType.DECODE:
+             prealloc_duration = (
+                 self.decode_transfer_queue_entry_time
+                 - self.decode_prealloc_queue_entry_time
+             )
+
+             transfer_duration = (
+                 self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+             )
+             queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+             forward_duration = self.completion_time - self.forward_entry_time
+
+             if SGLANG_TEST_REQUEST_TIME_STATS:
+                 assert (
+                     prealloc_duration >= 0
+                     and transfer_duration >= 0
+                     and queue_duration >= 0
+                     and forward_duration >= 0
+                 ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+             return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
+         else:
+             return "Invalid Time Stats"
+
+     def format_duration(self, duration: float) -> str:
+         return f"{duration * 1e3:.2f}ms"
+
+     def get_type(self) -> RequestType:
+         """Determine the type of request based on timestamp values."""
+         if (
+             self.prefill_bootstrap_queue_entry_time == 0.0
+             and self.prefill_transfer_queue_entry_time == 0.0
+             and self.decode_prealloc_queue_entry_time == 0.0
+             and self.decode_transfer_queue_entry_time == 0.0
+         ):
+             return self.RequestType.UNIFIED
+         elif (
+             self.prefill_bootstrap_queue_entry_time > 0.0
+             and self.prefill_transfer_queue_entry_time > 0.0
+         ):
+             return self.RequestType.PREFILL
+         elif (
+             self.decode_prealloc_queue_entry_time > 0.0
+             and self.decode_transfer_queue_entry_time > 0.0
+             and self.wait_queue_entry_time > 0.0
+         ):
+             return self.RequestType.DECODE
+         else:
+             return self.RequestType.INVALID
 
 
  @dataclass
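
Note: the new TimeStats dataclass infers a request's role purely from which queue timestamps were ever set (see get_type above). A minimal usage sketch for a unified, non-disaggregated request, assuming the class exactly as added in this hunk:

    import time

    from sglang.srt.metrics.collector import TimeStats

    stats = TimeStats()
    stats.wait_queue_entry_time = time.time()   # request enters the waiting queue
    stats.forward_entry_time = time.time()      # scheduler starts the forward pass
    stats.completion_time = time.time()         # last token produced

    # No disaggregation timestamps were set, so the type is UNIFIED.
    assert stats.get_type() == TimeStats.RequestType.UNIFIED
    print(stats)  # queue_duration=...ms, forward_duration=...ms, start_time=...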
@@ -26,18 +138,23 @@ class SchedulerStats
      gen_throughput: float = 0.0
      num_queue_reqs: int = 0
      cache_hit_rate: float = 0.0
+     num_grammar_queue_reqs: int = 0
      spec_accept_length: float = 0.0
      avg_request_queue_latency: float = 0.0
+     num_prefill_prealloc_queue_reqs: int = 0
+     num_prefill_infight_queue_reqs: int = 0
+     num_decode_prealloc_queue_reqs: int = 0
+     num_decode_transfer_queue_reqs: int = 0
 
 
  class SchedulerMetricsCollector:
 
      def __init__(self, labels: Dict[str, str]) -> None:
          # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-         from prometheus_client import Gauge, Histogram
+         from prometheus_client import Counter, Gauge
 
          self.labels = labels
-         self.last_log_time = time.time()
+         self.last_log_time = time.perf_counter()
 
          self.num_running_reqs = Gauge(
              name="sglang:num_running_reqs",
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector
              multiprocess_mode="mostrecent",
          )
 
+         self.num_grammar_queue_reqs = Gauge(
+             name="sglang:num_grammar_queue_reqs",
+             documentation="The number of requests in the grammar waiting queue.",
+             labelnames=labels.keys(),
+             multiprocess_mode="mostrecent",
+         )
+
          self.cache_hit_rate = Gauge(
              name="sglang:cache_hit_rate",
              documentation="The prefix cache hit rate.",
@@ -95,28 +219,98 @@
              multiprocess_mode="mostrecent",
          )
 
+         # Disaggregation queue metrics
+         self.num_prefill_prealloc_queue_reqs = Gauge(
+             name="sglang:num_prefill_prealloc_queue_reqs",
+             documentation="The number of requests in the prefill prealloc queue.",
+             labelnames=labels.keys(),
+             multiprocess_mode="mostrecent",
+         )
+
+         self.num_prefill_infight_queue_reqs = Gauge(
+             name="sglang:num_prefill_infight_queue_reqs",
+             documentation="The number of requests in the prefill infight queue.",
+             labelnames=labels.keys(),
+             multiprocess_mode="mostrecent",
+         )
+
+         self.num_decode_prealloc_queue_reqs = Gauge(
+             name="sglang:num_decode_prealloc_queue_reqs",
+             documentation="The number of requests in the decode prealloc queue.",
+             labelnames=labels.keys(),
+             multiprocess_mode="mostrecent",
+         )
+
+         self.num_decode_transfer_queue_reqs = Gauge(
+             name="sglang:num_decode_transfer_queue_reqs",
+             documentation="The number of requests in the decode transfer queue.",
+             labelnames=labels.keys(),
+             multiprocess_mode="mostrecent",
+         )
+
+         self.num_bootstrap_failed_reqs = Counter(
+             name="sglang:num_bootstrap_failed_reqs",
+             documentation="The number of bootstrap failed requests.",
+             labelnames=labels.keys(),
+         )
+
+         self.num_transfer_failed_reqs = Counter(
+             name="sglang:num_transfer_failed_reqs",
+             documentation="The number of transfer failed requests.",
+             labelnames=labels.keys(),
+         )
+
      def _log_gauge(self, gauge, data: Union[int, float]) -> None:
          # Convenience function for logging to gauge.
          gauge.labels(**self.labels).set(data)
 
+     def increment_bootstrap_failed_reqs(self) -> None:
+         self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
+
+     def increment_transfer_failed_reqs(self) -> None:
+         self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
+
      def log_stats(self, stats: SchedulerStats) -> None:
          self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
          self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
          self._log_gauge(self.token_usage, stats.token_usage)
          self._log_gauge(self.gen_throughput, stats.gen_throughput)
          self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+         self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
          self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
          self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
-         self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
-         self.last_log_time = time.time()
+
+         # Disaggregation metrics
+         self._log_gauge(
+             self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
+         )
+         self._log_gauge(
+             self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
+         )
+         self._log_gauge(
+             self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
+         )
+         self._log_gauge(
+             self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
+         )
+
+         self.last_log_time = time.perf_counter()
 
 
  class TokenizerMetricsCollector:
-     def __init__(self, labels: Dict[str, str]) -> None:
+     def __init__(
+         self,
+         labels: Dict[str, str],
+         bucket_time_to_first_token: Optional[List[float]] = None,
+         bucket_inter_token_latency: Optional[List[float]] = None,
+         bucket_e2e_request_latency: Optional[List[float]] = None,
+         collect_tokens_histogram: bool = False,
+     ) -> None:
          # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
          from prometheus_client import Counter, Histogram
 
          self.labels = labels
+         self.collect_tokens_histogram = collect_tokens_histogram
 
          self.prompt_tokens_total = Counter(
              name="sglang:prompt_tokens_total",
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
              labelnames=labels.keys(),
          )
 
+         if collect_tokens_histogram:
+             bucket_prompt_tokens = [
+                 100,
+                 300,
+                 500,
+                 700,
+                 1000,
+                 1500,
+                 2000,
+                 3000,
+                 4000,
+                 5000,
+                 6000,
+                 7000,
+                 8000,
+                 9000,
+                 10000,
+                 12000,
+                 15000,
+                 20000,
+                 22000,
+                 25000,
+                 30000,
+                 35000,
+                 40000,
+             ]
+             self.prompt_tokens_histogram = Histogram(
+                 name="sglang:prompt_tokens_histogram",
+                 documentation="Histogram of prompt token length.",
+                 labelnames=labels.keys(),
+                 buckets=bucket_prompt_tokens,
+             )
+             bucket_generation_tokens = [
+                 100,
+                 300,
+                 500,
+                 1000,
+                 1200,
+                 1500,
+                 1700,
+                 2000,
+                 2500,
+                 3000,
+                 3500,
+                 4000,
+                 4500,
+                 5000,
+                 6000,
+                 7000,
+                 8000,
+                 9000,
+                 10000,
+             ]
+             self.generation_tokens_histogram = Histogram(
+                 name="sglang:generation_tokens_histogram",
+                 documentation="Histogram of generation token length.",
+                 labelnames=labels.keys(),
+                 buckets=bucket_generation_tokens,
+             )
+
          self.cached_tokens_total = Counter(
              name="sglang:cached_tokens_total",
              documentation="Number of cached prompt tokens.",
@@ -142,11 +396,14 @@
              labelnames=labels.keys(),
          )
 
-         self.histogram_time_to_first_token = Histogram(
-             name="sglang:time_to_first_token_seconds",
-             documentation="Histogram of time to first token in seconds.",
+         self.num_so_requests_total = Counter(
+             name="sglang:num_so_requests_total",
+             documentation="Number of structured output requests processed.",
              labelnames=labels.keys(),
-             buckets=[
+         )
+
+         if bucket_time_to_first_token is None:
+             bucket_time_to_first_token = [
                  0.1,
                  0.2,
                  0.4,
@@ -165,14 +422,33 @@
                  100,
                  200,
                  400,
-             ],
-         )
+             ]
 
-         self.histogram_inter_token_latency_seconds = Histogram(
-             name="sglang:inter_token_latency_seconds",
-             documentation="Histogram of inter-token latency in seconds.",
-             labelnames=labels.keys(),
-             buckets=[
+         if bucket_e2e_request_latency is None:
+             bucket_e2e_request_latency = [
+                 0.1,
+                 0.2,
+                 0.4,
+                 0.6,
+                 0.8,
+                 1,
+                 2,
+                 4,
+                 6,
+                 8,
+                 10,
+                 20,
+                 40,
+                 60,
+                 80,
+                 100,
+                 200,
+                 400,
+                 800,
+             ]
+
+         if bucket_inter_token_latency is None:
+             bucket_inter_token_latency = [
                  0.002,
                  0.004,
                  0.006,
@@ -196,34 +472,27 @@
                  4.000,
                  6.000,
                  8.000,
-             ],
+             ]
+
+         self.histogram_time_to_first_token = Histogram(
+             name="sglang:time_to_first_token_seconds",
+             documentation="Histogram of time to first token in seconds.",
+             labelnames=labels.keys(),
+             buckets=bucket_time_to_first_token,
+         )
+
+         self.histogram_inter_token_latency_seconds = Histogram(
+             name="sglang:inter_token_latency_seconds",
+             documentation="Histogram of inter-token latency in seconds.",
+             labelnames=labels.keys(),
+             buckets=bucket_inter_token_latency,
          )
 
          self.histogram_e2e_request_latency = Histogram(
              name="sglang:e2e_request_latency_seconds",
              documentation="Histogram of End-to-end request latency in seconds",
              labelnames=labels.keys(),
-             buckets=[
-                 0.1,
-                 0.2,
-                 0.4,
-                 0.6,
-                 0.8,
-                 1,
-                 2,
-                 4,
-                 6,
-                 8,
-                 10,
-                 20,
-                 40,
-                 60,
-                 80,
-                 100,
-                 200,
-                 400,
-                 800,
-             ],
+             buckets=bucket_e2e_request_latency,
          )
 
      def _log_histogram(self, histogram, data: Union[int, float]) -> None:
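
Note: with the bucket lists hoisted into constructor parameters, deployments can override the defaults without patching the source. A hedged usage sketch, assuming the __init__ signature shown in this diff and a hypothetical label set:

    from sglang.srt.metrics.collector import TokenizerMetricsCollector

    collector = TokenizerMetricsCollector(
        labels={"model_name": "demo-model"},                # hypothetical labels
        bucket_time_to_first_token=[0.05, 0.1, 0.5, 1, 5],  # custom TTFT buckets
        collect_tokens_histogram=True,                      # enable token histograms
    )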
@@ -235,13 +504,19 @@
          generation_tokens: int,
          cached_tokens: int,
          e2e_latency: float,
+         has_grammar: bool,
      ):
          self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
          self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
          if cached_tokens > 0:
              self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
          self.num_requests_total.labels(**self.labels).inc(1)
+         if has_grammar:
+             self.num_so_requests_total.labels(**self.labels).inc(1)
          self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+         if self.collect_tokens_histogram:
+             self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
+             self._log_histogram(self.generation_tokens_histogram, generation_tokens)
 
      def observe_time_to_first_token(self, value: float):
          self.histogram_time_to_first_token.labels(**self.labels).observe(value)
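
Note: observe_one_finished_request gained a has_grammar argument (no default, so callers must pass it), which feeds the new sglang:num_so_requests_total counter. Continuing the sketch above:

    collector.observe_one_finished_request(
        prompt_tokens=512,
        generation_tokens=128,
        cached_tokens=256,
        e2e_latency=1.7,
        has_grammar=True,  # counted in sglang:num_so_requests_total
    )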
sglang/srt/mm_utils.py CHANGED
@@ -36,6 +36,16 @@ from io import BytesIO
  import numpy as np
  from PIL import Image
 
+ from sglang.srt.utils import flatten_nested_list
+
+
+ def has_valid_data(data) -> bool:
+     if data is None:
+         return False
+     if isinstance(data, list):
+         return any(has_valid_data(item) for item in flatten_nested_list(data))
+     return True
+
 
  def select_best_resolution(original_size, possible_resolutions):
      """
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -19,7 +19,7 @@ import bisect
  import inspect
  import os
  from contextlib import contextmanager
- from typing import TYPE_CHECKING, Callable
+ from typing import TYPE_CHECKING, Callable, Optional, Union
 
  import torch
  import tqdm
@@ -30,6 +30,7 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import (
      CaptureHiddenMode,
      ForwardBatch,
@@ -40,14 +41,18 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
  from sglang.srt.utils import (
      get_available_gpu_memory,
      get_device_memory_capacity,
-     is_hip,
      rank0_log,
  )
 
  if TYPE_CHECKING:
      from sglang.srt.model_executor.model_runner import ModelRunner
 
- _is_hip = is_hip()
+ # Detect whether the current forward pass is in capture mode
+ is_capture_mode = False
+
+
+ def get_is_capture_mode():
+     return is_capture_mode
 
 
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -137,7 +142,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
      )
 
      gpu_mem = get_device_memory_capacity()
-     # Batch size of each rank will not become so large when DP is on
      if gpu_mem is not None and gpu_mem > 96 * 1024:
          capture_bs += list(range(160, 257, 8))
 
@@ -148,12 +152,15 @@
          model_runner.req_to_token_pool.size
      ]
 
-     capture_bs = list(sorted(set(capture_bs)))
-
-     assert len(capture_bs) > 0 and capture_bs[0] > 0
-     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
      if server_args.cuda_graph_max_bs:
          capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+         if max(capture_bs) < server_args.cuda_graph_max_bs:
+             capture_bs += list(
+                 range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+             )
+     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+     capture_bs = list(sorted(set(capture_bs)))
+     assert len(capture_bs) > 0 and capture_bs[0] > 0
      compile_bs = (
          [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
          if server_args.enable_torch_compile
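
Note: the reordered logic pads capture_bs up to cuda_graph_max_bs in steps of 16, then clamps to the request pool size, deduplicates, and sorts. A pure-Python sketch of just that arithmetic, with hypothetical inputs:

    capture_bs = [1, 2, 4, 8, 16, 24, 32]  # hypothetical initial list
    cuda_graph_max_bs = 96                 # hypothetical --cuda-graph-max-bs
    req_to_token_pool_size = 512           # hypothetical pool size

    capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
    if max(capture_bs) < cuda_graph_max_bs:
        capture_bs += list(range(max(capture_bs), cuda_graph_max_bs + 1, 16))
    capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
    capture_bs = sorted(set(capture_bs))
    print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32, 48, 64, 80, 96]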
@@ -211,7 +218,10 @@
          # Attention backend
          self.max_bs = max(self.capture_bs)
          self.max_num_token = self.max_bs * self.num_tokens_per_bs
-         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+         if global_server_args_dict["attention_backend"] == "flashmla":
+             self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+         else:
+             self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
          self.seq_len_fill_value = (
              self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
          )
@@ -237,6 +247,7 @@
          self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
          self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
          self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
 
          # pipeline parallelism
          if self.pp_size > 1:
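
Note: num_token_non_padded is a preallocated one-element tensor because a captured CUDA graph replays fixed memory addresses; per-batch scalars must be written into static buffers rather than passed as Python ints. The pattern in isolation (CPU tensor here so the sketch runs anywhere; the runner allocates on the GPU):

    import torch

    # Allocated once, before capture; the captured graph reads this exact storage.
    num_token_non_padded = torch.zeros((1,), dtype=torch.int32)

    # Before each replay, update the buffer in place; no re-capture is needed.
    num_token_non_padded[...] = 42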
@@ -296,28 +307,23 @@
              self.capture()
          except RuntimeError as e:
              raise Exception(
-                 f"Capture cuda graph failed: {e}\n"
+                 f"Capture CUDA graph failed: {e}\n"
                  "Possible solutions:\n"
                  "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                  "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                  "3. disable torch compile by not using --enable-torch-compile\n"
-                 "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
+                 "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
              )
 
      @contextmanager
      def model_capture_mode(self):
-         if hasattr(self.model_runner.model, "capture_mode"):
-             self.model_runner.model.capture_mode = True
-         if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-             self.model_runner.token_to_kv_pool.capture_mode = True
+         global is_capture_mode
+         is_capture_mode = True
 
          yield
 
-         if hasattr(self.model_runner.model, "capture_mode"):
-             self.model_runner.model.capture_mode = False
-         if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-             self.model_runner.token_to_kv_pool.capture_mode = False
+         is_capture_mode = False
 
      def can_run(self, forward_batch: ForwardBatch):
          if self.enable_dp_attention or self.enable_sp_layernorm:
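
Note: the per-object capture_mode attributes are replaced by a single module-level flag plus the get_is_capture_mode() accessor, so any code can ask "are we capturing?" without holding a reference to the runner. The pattern in isolation (hypothetical standalone module; a try/finally is added here for safety, which the diff's version omits):

    from contextlib import contextmanager

    _is_capture_mode = False  # module-level flag, as in cuda_graph_runner.py

    def get_is_capture_mode() -> bool:
        return _is_capture_mode

    @contextmanager
    def capture_mode():
        global _is_capture_mode
        _is_capture_mode = True
        try:
            yield
        finally:
            _is_capture_mode = False

    with capture_mode():
        assert get_is_capture_mode()
    assert not get_is_capture_mode()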
@@ -400,6 +406,7 @@
          else:
              encoder_lens = None
          mrope_positions = self.mrope_positions[:, :bs]
+         self.num_token_non_padded[...] = num_tokens
 
          # pipeline parallelism
          if self.pp_size > 1:
@@ -458,6 +465,7 @@
              spec_info=spec_info,
              capture_hidden_mode=self.capture_hidden_mode,
              lora_paths=lora_paths,
+             num_token_non_padded=self.num_token_non_padded,
          )
 
          if lora_paths is not None:
@@ -553,6 +561,7 @@
          self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
          self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
          self.positions[:raw_num_token].copy_(forward_batch.positions)
+         self.num_token_non_padded[...] = len(forward_batch.input_ids)
          if forward_batch.seq_lens_cpu is not None:
              if bs != raw_bs:
                  self.seq_lens_cpu.fill_(1)
@@ -605,6 +614,7 @@
 
          # Replay
          self.graphs[self.bs].replay()
+
          output = self.output_buffers[self.bs]
          if isinstance(output, LogitsProcessorOutput):
              return LogitsProcessorOutput(