sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,7 +1,8 @@
  # SGLang public APIs
 
  # Frontend Language APIs
- from sglang.api import (
+ from sglang.global_config import global_config
+ from sglang.lang.api import (
  Engine,
  Runtime,
  assistant,
@@ -25,22 +26,26 @@ from sglang.api import (
  user_end,
  video,
  )
- from sglang.global_config import global_config
  from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.lang.choices import (
  greedy_token_selection,
  token_length_normalized,
  unconditional_likelihood_normalized,
  )
+
+ # Lazy import some libraries
  from sglang.utils import LazyImport
  from sglang.version import __version__
 
- ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
  Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
  LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
  OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
  VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
 
+ # Runtime Engine APIs
+ ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
+ Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
+
  __all__ = [
  "Engine",
  "Runtime",
sglang/bench_one_batch.py CHANGED
@@ -43,6 +43,7 @@ I'm going to the park
  """
 
  import argparse
+ import copy
  import dataclasses
  import itertools
  import json
@@ -60,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.distributed.parallel_state import destroy_distributed_environment
  from sglang.srt.entrypoints.engine import _set_envs_and_config
  from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.managers.scheduler import Scheduler
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -84,12 +86,14 @@ class BenchArgs:
  batch_size: Tuple[int] = (1,)
  input_len: Tuple[int] = (1024,)
  output_len: Tuple[int] = (16,)
+ prompt_filename: str = ""
  result_filename: str = "result.jsonl"
  correctness_test: bool = False
  # This is only used for correctness test
  cut_len: int = 4
  log_decode_step: int = 0
  profile: bool = False
+ profile_record_shapes: bool = False
  profile_filename_prefix: str = "profile"
 
  @staticmethod
@@ -104,6 +108,9 @@ class BenchArgs:
  parser.add_argument(
  "--output-len", type=int, nargs="+", default=BenchArgs.output_len
  )
+ parser.add_argument(
+ "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+ )
  parser.add_argument(
  "--result-filename", type=str, default=BenchArgs.result_filename
  )
@@ -118,6 +125,11 @@ class BenchArgs:
  parser.add_argument(
  "--profile", action="store_true", help="Use Torch Profiler."
  )
+ parser.add_argument(
+ "--profile-record-shapes",
+ action="store_true",
+ help="Record tensor shapes in profiling results.",
+ )
  parser.add_argument(
  "--profile-filename-prefix",
  type=str,
@@ -165,12 +177,16 @@ def load_model(server_args, port_args, tp_rank):
  return model_runner, tokenizer
 
 
- def prepare_inputs_for_correctness_test(bench_args, tokenizer):
- prompts = [
- "The capital of France is",
- "The capital of the United Kindom is",
- "Today is a sunny day and I like",
- ]
+ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+ prompts = (
+ custom_prompts
+ if custom_prompts
+ else [
+ "The capital of France is",
+ "The capital of the United Kindom is",
+ "Today is a sunny day and I like",
+ ]
+ )
  input_ids = [tokenizer.encode(p) for p in prompts]
  sampling_params = SamplingParams(
  temperature=0,
@@ -211,8 +227,14 @@ def prepare_extend_inputs_for_correctness_test(
  return reqs
 
 
- def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
- input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+ def prepare_synthetic_inputs_for_latency_test(
+ batch_size, input_len, custom_inputs=None
+ ):
+ input_ids = (
+ custom_inputs
+ if custom_inputs
+ else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+ )
  sampling_params = SamplingParams(
  temperature=0,
  max_new_tokens=BenchArgs.output_len,
@@ -279,11 +301,40 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
  disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
  spec_algorithm=SpeculativeAlgorithm.NONE,
  speculative_num_draft_tokens=None,
+ enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
+ enable_deepep_moe=MoeA2ABackend(
+ model_runner.server_args.moe_a2a_backend
+ ).is_deepep(),
+ deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
  require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
  disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
  )
 
 
+ def _read_prompts_from_file(prompt_file, rank_print):
+ """Read custom prompts from the file specified by `--prompt-filename`."""
+ if not prompt_file:
+ return []
+ if not os.path.exists(prompt_file):
+ rank_print(
+ f"Custom prompt file {prompt_file} not found. Using default inputs..."
+ )
+ return []
+ with open(prompt_file, "r") as pf:
+ return pf.readlines()
+
+
+ def _save_profile_trace_results(profiler, filename):
+ parent_dir = os.path.dirname(os.path.abspath(filename))
+ os.makedirs(parent_dir, exist_ok=True)
+ profiler.export_chrome_trace(filename)
+ print(
+ profiler.key_averages(group_by_input_shape=True).table(
+ sort_by="self_cpu_time_total"
+ )
+ )
+
+
  def correctness_test(
  server_args,
  port_args,
@@ -298,7 +349,10 @@ def correctness_test(
  model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
  # Prepare inputs
- input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+ custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+ input_ids, reqs = prepare_inputs_for_correctness_test(
+ bench_args, tokenizer, custom_prompts
+ )
  rank_print(f"\n{input_ids=}\n")
 
  if bench_args.cut_len > 0:
@@ -344,6 +398,7 @@ def latency_test_run_once(
  device,
  log_decode_step,
  profile,
+ profile_record_shapes,
  profile_filename_prefix,
  ):
  max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +429,7 @@ def latency_test_run_once(
  torch.profiler.ProfilerActivity.CUDA,
  ],
  with_stack=True,
+ record_shapes=profile_record_shapes,
  )
  profiler.start()
 
@@ -391,10 +447,30 @@ def latency_test_run_once(
  measurement_results["prefill_latency"] = prefill_latency
  measurement_results["prefill_throughput"] = throughput
 
+ if profile:
+ profiler.stop()
+ profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+ _save_profile_trace_results(profiler, profile_filename)
+ rank_print(
+ f"torch profiler chrome trace for prefill saved to {profile_filename}"
+ )
+
  # Decode
  decode_latencies = []
  for i in range(output_len - 1):
  synchronize(device)
+ if profile and i == output_len / 2:
+ profiler = None
+ profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ with_stack=True,
+ record_shapes=profile_record_shapes,
+ )
+ profiler.start()
+
  tic = time.perf_counter()
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
  synchronize(device)
@@ -407,13 +483,13 @@ def latency_test_run_once(
  f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
  )
 
- if profile:
- profiler.stop()
- profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
- parent_dir = os.path.dirname(os.path.abspath(profile_filename))
- os.makedirs(parent_dir, exist_ok=True)
- profiler.export_chrome_trace(profile_filename)
- rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+ if profile and i == output_len / 2:
+ profiler.stop()
+ profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+ _save_profile_trace_results(profiler, profile_filename)
+ rank_print(
+ f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+ )
 
  # Record decode timing from 2nd output
  if output_len > 1:
@@ -469,17 +545,42 @@ def latency_test(
  server_args.device,
  log_decode_step=0,
  profile=False,
+ profile_record_shapes=False,
  profile_filename_prefix="", # not used
  )
 
  rank_print("Benchmark ...")
 
+ custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+ custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+ custom_input_len = len(custom_inputs)
+
  # Run the sweep
  result_list = []
  for bs, il, ol in itertools.product(
  bench_args.batch_size, bench_args.input_len, bench_args.output_len
  ):
- reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+ bs_aligned_inputs = []
+ if custom_inputs:
+ if custom_input_len == bs:
+ bs_aligned_inputs = custom_inputs
+ elif custom_input_len > bs:
+ rank_print(
+ f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+ f"Using the first {bs} prompts."
+ )
+ bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+ else:
+ rank_print(
+ f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+ f"Pad to the desired batch_size with the last prompt."
+ )
+ bs_aligned_inputs = copy.deepcopy(custom_inputs)
+ bs_aligned_inputs.extend(
+ [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+ )
+
+ reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
  ret = latency_test_run_once(
  bench_args.run_name,
  model_runner,
@@ -491,6 +592,7 @@ def latency_test(
  server_args.device,
  bench_args.log_decode_step,
  bench_args.profile if tp_rank == 0 else None,
+ bench_args.profile_record_shapes if tp_rank == 0 else None,
  bench_args.profile_filename_prefix,
  )
  if ret is not None:
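Together, the new options let the one-batch benchmark run on user-supplied prompts and capture separate prefill and decode profiler traces. A hypothetical invocation (model path and file names are placeholders):

    python -m sglang.bench_one_batch \
        --model-path meta-llama/Llama-3.1-8B-Instruct \
        --batch-size 4 --input-len 1024 --output-len 16 \
        --prompt-filename prompts.txt \
        --profile --profile-record-shapes \
        --profile-filename-prefix bench_trace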
sglang/lang/chat_template.py CHANGED
@@ -505,6 +505,22 @@ register_chat_template(
  )
  )
 
+ # Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
+ register_chat_template(
+ ChatTemplate(
+ name="glm-4v",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": ("<|system|>\n", "\n"),
+ "user": ("<|user|>\n", "\n"),
+ "assistant": ("<|assistant|>\n", "\n"),
+ },
+ style=ChatTemplateStyle.PLAIN,
+ stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
+ image_token="<|image|>",
+ )
+ )
+
 
  @register_chat_template_matching_function
  def match_deepseek(model_path: str):
@@ -562,6 +578,8 @@ def match_chat_ml(model_path: str):
  return "chatml"
  if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
  return "qwen2-vl"
+ if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
+ return "glm-4v"
  if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
  r"llava", model_path, re.IGNORECASE
  ):
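Model paths matching the new regex (for example anything containing "glm-4v", "glm4v", or "glm-4.5v") are routed to the "glm-4v" template. With ChatTemplateStyle.PLAIN and the prefixes registered above, a single-image user turn would render roughly as follows (a sketch of the expected prompt text derived from the registered prefixes, not output captured from the code; exact image-token placement depends on the frontend):

    <|user|>
    <|image|>Describe this picture.
    <|assistant|>

Generation then stops on any of the registered stop strings such as <|user|> or <|endoftext|>.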
sglang/srt/bench_utils.py ADDED
@@ -0,0 +1,137 @@
+ import os
+ import sys
+ from contextlib import nullcontext
+
+ import torch
+
+
+ # NOTE copied and modified from DeepGEMM
+ class suppress_stdout_stderr:
+ def __enter__(self):
+ self.outnull_file = open(os.devnull, "w")
+ self.errnull_file = open(os.devnull, "w")
+
+ self.old_stdout_fileno_undup = sys.stdout.fileno()
+ self.old_stderr_fileno_undup = sys.stderr.fileno()
+
+ self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+ self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+ self.old_stdout = sys.stdout
+ self.old_stderr = sys.stderr
+
+ os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+ os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+
+ sys.stdout = self.outnull_file
+ sys.stderr = self.errnull_file
+ return self
+
+ def __exit__(self, *_):
+ sys.stdout = self.old_stdout
+ sys.stderr = self.old_stderr
+
+ os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+ os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+ os.close(self.old_stdout_fileno)
+ os.close(self.old_stderr_fileno)
+
+ self.outnull_file.close()
+ self.errnull_file.close()
+
+
+ # NOTE copied and modified from DeepGEMM
+ def bench_kineto(
+ fn,
+ kernel_names,
+ num_tests: int = 30,
+ suppress_kineto_output: bool = False,
+ trace_path: str = None,
+ flush_l2: bool = True,
+ with_multiple_kernels: bool = False,
+ ):
+ # Conflict with Nsight Systems
+ using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
+
+ # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
+ flush_l2_size = int(8e9 // 4)
+
+ # For some auto-tuning kernels with prints
+ fn()
+
+ # Profile
+ suppress = (
+ suppress_stdout_stderr
+ if suppress_kineto_output and not using_nsys
+ else nullcontext
+ )
+ with suppress():
+ schedule = (
+ torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
+ if not using_nsys
+ else None
+ )
+ profiler = (
+ torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
+ )
+ if not using_nsys
+ else nullcontext()
+ )
+ with profiler:
+ for i in range(2):
+ for _ in range(num_tests):
+ if flush_l2:
+ torch.empty(
+ flush_l2_size, dtype=torch.int, device="cuda"
+ ).zero_()
+ fn()
+
+ if not using_nsys:
+ profiler.step()
+
+ # Return 1 if using Nsight Systems
+ if using_nsys:
+ return 1
+
+ # Parse the profiling table
+ assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
+ is_tuple = isinstance(kernel_names, tuple)
+ prof_lines = (
+ profiler.key_averages()
+ .table(sort_by="cuda_time_total", max_name_column_width=100)
+ .split("\n")
+ )
+ kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
+ assert all([isinstance(name, str) for name in kernel_names])
+ if not with_multiple_kernels:
+ for name in kernel_names:
+ assert (
+ sum([name in line for line in prof_lines]) == 1
+ ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
+
+ # Save chrome traces
+ if trace_path is not None:
+ profiler.export_chrome_trace(trace_path)
+
+ # Return average kernel times
+ units = {"ms": 1e3, "us": 1e6}
+ kernel_times = []
+ for name in kernel_names:
+ total_time = 0
+ total_num = 0
+ for line in prof_lines:
+ if name in line:
+ time_str = line.split()[-2]
+ num_str = line.split()[-1]
+ for unit, scale in units.items():
+ if unit in time_str:
+ total_time += (
+ float(time_str.replace(unit, "")) / scale * int(num_str)
+ )
+ total_num += int(num_str)
+ break
+ kernel_times.append(total_time / total_num)
+
+ return tuple(kernel_times) if is_tuple else kernel_times[0]
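bench_kineto warms the callable up once, optionally flushes L2 between iterations, and averages the CUDA time of the named kernel(s) from the Kineto table. A usage sketch (the "gemm" substring is an assumption about the kernel name and must match exactly one table row unless with_multiple_kernels=True):

    import torch
    from sglang.srt.bench_utils import bench_kineto

    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    # Average time in seconds of the matmul kernel, matched by substring in the profiler table.
    t = bench_kineto(lambda: a @ b, kernel_names="gemm", num_tests=30)
    print(f"~{t * 1e6:.1f} us per call")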
sglang/srt/configs/model_config.py CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
  get_context_length,
  get_generation_config,
  get_hf_text_config,
+ get_sparse_attention_config,
  )
  from sglang.srt.layers.quantization import QUANTIZATION_METHODS
  from sglang.srt.server_args import ServerArgs
@@ -63,13 +64,12 @@ class ModelConfig:
  hybrid_kvcache_ratio: Optional[float] = None,
  model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
  ) -> None:
-
+ # Parse args
  self.model_path = model_path
  self.revision = revision
  self.quantization = quantization
  self.model_impl = model_impl
 
- # Parse args
  self.maybe_pull_model_tokenizer_from_remote()
  self.model_override_args = json.loads(model_override_args)
  kwargs = {}
@@ -133,6 +133,12 @@ class ModelConfig:
 
  if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
  self.hf_config.architectures[0] = "MiMoMTP"
+ if (
+ is_draft_model
+ and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
+ ):
+ self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
+
  # Check model type
  self.is_generation = is_generation_model(
  self.hf_config.architectures, is_embedding
@@ -270,15 +276,16 @@ class ModelConfig:
  # Verify quantization
  self._verify_quantization()
 
+ # Verify dual-chunk attention config
+ self._verify_dual_chunk_attention_config()
+
  # Cache attributes
  self.hf_eos_token_id = self.get_hf_eos_token_id()
 
- config = self.hf_config
-
  # multimodal
- self.image_token_id = getattr(config, "image_token_id", None) or getattr(
- config, "image_token_index", None
- )
+ self.image_token_id = getattr(
+ self.hf_config, "image_token_id", None
+ ) or getattr(self.hf_config, "image_token_index", None)
 
  @staticmethod
  def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -297,6 +304,13 @@ class ModelConfig:
  **kwargs,
  )
 
+ def get_total_num_attention_heads(self) -> int:
+ return self.num_attention_heads
+
+ def get_num_attention_heads(self, tensor_parallel_size) -> int:
+ total_num_attention_heads = self.num_attention_heads
+ return max(1, total_num_attention_heads // tensor_parallel_size)
+
  # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
  def get_total_num_kv_heads(self) -> int:
  """Returns the total number of KV heads."""
@@ -401,6 +415,8 @@ class ModelConfig:
  "fbgemm_fp8",
  "w8a8_fp8",
  "petit_nvfp4",
+ "quark",
+ "mxfp4",
  ]
  optimized_quantization_methods = [
  "fp8",
@@ -482,6 +498,23 @@ class ModelConfig:
  self.quantization,
  )
 
+ def _verify_dual_chunk_attention_config(self) -> None:
+ if hasattr(self.hf_config, "dual_chunk_attention_config"):
+ # Try loading the sparse attention config
+ sparse_attn_config = get_sparse_attention_config(self.model_path)
+ if not sparse_attn_config:
+ return
+ self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+ sparse_attn_config
+ )
+ if (
+ "sparse_attention_enabled"
+ not in self.hf_config.dual_chunk_attention_config
+ ):
+ self.hf_config.dual_chunk_attention_config[
+ "sparse_attention_enabled"
+ ] = True
+
  def get_hf_eos_token_id(self) -> Optional[Set[int]]:
  eos_ids = getattr(self.hf_config, "eos_token_id", None)
  if eos_ids is not None:
@@ -626,6 +659,8 @@ multimodal_model_archs = [
  "DeepseekVL2ForCausalLM",
  "Gemma3ForConditionalGeneration",
  "Gemma3nForConditionalGeneration",
+ "Glm4vForConditionalGeneration",
+ "Glm4vMoeForConditionalGeneration",
  "Grok1VForCausalLM",
  "Grok1AForCausalLM",
  "LlavaLlamaForCausalLM",
sglang/srt/conversation.py CHANGED
@@ -30,8 +30,10 @@ import re
  from enum import IntEnum, auto
  from typing import Callable, Dict, List, Optional, Tuple, Union
 
+ from typing_extensions import Literal
+
  from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
- from sglang.srt.utils import read_system_prompt_from_file
+ from sglang.srt.utils import ImageData, read_system_prompt_from_file
 
 
  class SeparatorStyle(IntEnum):
@@ -91,7 +93,7 @@ class Conversation:
  video_token: str = "<video>"
  audio_token: str = "<audio>"
 
- image_data: Optional[List[str]] = None
+ image_data: Optional[List[ImageData]] = None
  video_data: Optional[List[str]] = None
  modalities: Optional[List[str]] = None
  stop_token_ids: Optional[int] = None
@@ -381,9 +383,9 @@ class Conversation:
  """Append a new message."""
  self.messages.append([role, message])
 
- def append_image(self, image: str):
+ def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
  """Append a new image."""
- self.image_data.append(image)
+ self.image_data.append(ImageData(url=image, detail=detail))
 
  def append_video(self, video: str):
  """Append a new video."""
@@ -627,7 +629,9 @@ def generate_chat_conv(
  real_content = image_token + real_content
  else:
  real_content += image_token
- conv.append_image(content.image_url.url)
+ conv.append_image(
+ content.image_url.url, content.image_url.detail
+ )
  elif content.type == "video_url":
  real_content += video_token
  conv.append_video(content.video_url.url)
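The detail hint from OpenAI-style image_url content is now carried through to ImageData rather than dropped. A hypothetical chat payload that exercises it (model name and URL are placeholders):

    payload = {
        "model": "zai-org/GLM-4.5V",  # placeholder model name
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png", "detail": "high"},
                    },
                ],
            }
        ],
    }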
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -25,10 +25,13 @@ class KVArgs:
  gpu_id: int
  # for different tp
  decode_tp_size: int
- # for pp prefill
- prefill_pp_size: int
  kv_head_num: int
  page_size: int
+ # for pp prefill
+ prefill_pp_size: int
+ pp_rank: int
+ # for system dp
+ system_dp_rank: int
 
 
  class KVPoll: