sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
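To inspect these changes locally rather than relying on this listing, one rough approach is to download the two wheels and diff individual members. This is a sketch only: it assumes network access, that `pip download` can fetch both wheels, and the `wheels/` directory name and chosen member are illustrative.

    # Sketch: download both sglang wheels and diff one file between them.
    import difflib
    import pathlib
    import subprocess
    import zipfile

    for version in ("0.5.4", "0.5.4.post2"):
        subprocess.run(
            ["pip", "download", f"sglang=={version}", "--no-deps",
             "--dest", f"wheels/{version}"],
            check=True,
        )

    def read_member(version: str, member: str) -> list[str]:
        # Each wheel is a zip archive; read one file out of it.
        wheel = next(pathlib.Path(f"wheels/{version}").glob("sglang-*.whl"))
        with zipfile.ZipFile(wheel) as zf:
            return zf.read(member).decode("utf-8").splitlines(keepends=True)

    member = "sglang/version.py"
    diff = difflib.unified_diff(
        read_member("0.5.4", member),
        read_member("0.5.4.post2", member),
        fromfile=f"0.5.4/{member}",
        tofile=f"0.5.4.post2/{member}",
    )
    print("".join(diff))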
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -11,6 +11,11 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruc
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
 ## run with profiling:
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
+## run with profiling to custom directory:
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
+## run with CUDA profiler (nsys):
+nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER
 # Usage (correctness test):
 python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

@@ -93,6 +98,68 @@ profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
 ]


+def start_profile(profiler_activities, profile_record_shapes=False, rank_print=print):
+    """
+    Abstracted function to start profiling based on profiler_activities.
+    Returns profiler object (or None).
+    """
+    if "CUDA_PROFILER" in profiler_activities:
+        try:
+            torch.cuda.cudart().cudaProfilerStart()
+            rank_print("CUDA Profiler started (nsys will begin capturing)")
+        except Exception as e:
+            rank_print(f"Failed to start CUDA profiler: {e}")
+        return None
+    else:
+        activities = []
+        if "CPU" in profiler_activities:
+            activities.append(torch.profiler.ProfilerActivity.CPU)
+        if "GPU" in profiler_activities:
+            activities.append(torch.profiler.ProfilerActivity.CUDA)
+        if activities:
+            profiler = torch.profiler.profile(
+                activities=activities,
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+            return profiler
+    return None
+
+
+def stop_profile(
+    profiler,
+    profiler_activities,
+    rank_print=print,
+    save_trace=False,
+    trace_filename=None,
+    stage=None,
+):
+    """
+    Abstracted function to stop profiling based on profiler_activities.
+    Optionally saves trace results and prints completion messages.
+    """
+    if "CUDA_PROFILER" in profiler_activities:
+        try:
+            torch.cuda.cudart().cudaProfilerStop()
+            rank_print("CUDA Profiler stopped (nsys should dump traces)")
+        except Exception as e:
+            rank_print(f"Failed to stop CUDA profiler: {e}")
+    elif profiler is not None:
+        profiler.stop()
+
+    if save_trace:
+        if profiler is not None:
+            if trace_filename:
+                _save_profile_trace_results(profiler, trace_filename)
+            stage_desc = f"for {stage}" if stage else ""
+            rank_print(
+                f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
+            )
+        if "CUDA_PROFILER" in profiler_activities:
+            rank_print(f"CUDA profiler trace for {stage} completed")
+
+
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"
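As a rough usage sketch of the two helpers added above (not part of the package; `run_benchmarked_step` is a hypothetical placeholder and the trace path is illustrative):

    # Hypothetical caller of start_profile/stop_profile from this diff.
    activities = ("CPU", "GPU")  # or ("CUDA_PROFILER",) to drive an nsys capture instead
    profiler = start_profile(activities, profile_record_shapes=True)

    run_benchmarked_step()  # placeholder for the work being measured

    stop_profile(
        profiler,
        activities,
        save_trace=True,
        trace_filename="/tmp/profile_batch1_input256_output32_prefill.trace.json.gz",
        stage="prefill",
    )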
@@ -107,6 +174,8 @@ class BenchArgs:
     log_decode_step: int = 0
     profile: bool = False
     profile_record_shapes: bool = False
+    profiler_activities: Tuple[str] = ("CPU", "GPU")
+    profile_stage: str = "all"
     profile_filename_prefix: str = "profile"

     @staticmethod
@@ -135,14 +204,27 @@ class BenchArgs:
             default=BenchArgs.log_decode_step,
             help="Log decode latency by step, default is set to zero to disable.",
         )
-        parser.add_argument(
-            "--profile", action="store_true", help="Use Torch Profiler."
-        )
+        parser.add_argument("--profile", action="store_true", help="Enable profiling.")
         parser.add_argument(
             "--profile-record-shapes",
             action="store_true",
             help="Record tensor shapes in profiling results.",
         )
+        parser.add_argument(
+            "--profiler_activities",
+            type=str,
+            nargs="+",
+            default=["CPU", "GPU"],
+            choices=["CPU", "GPU", "CUDA_PROFILER"],
+            help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
+        )
+        parser.add_argument(
+            "--profile-stage",
+            type=str,
+            default=BenchArgs.profile_stage,
+            choices=["all", "prefill", "decode"],
+            help="Which stage to profile: all, prefill, or decode only.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -337,6 +419,18 @@ def _read_prompts_from_file(prompt_file, rank_print):
     return pf.readlines()


+def _get_torch_profiler_output_dir():
+    return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+
+
+def _create_torch_profiler_filename(
+    profile_filename_prefix, batch_size, input_len, output_len, stage
+):
+    output_dir = _get_torch_profiler_output_dir()
+    filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
+    return os.path.join(output_dir, filename)
+
+
 def _save_profile_trace_results(profiler, filename):
     parent_dir = os.path.dirname(os.path.abspath(filename))
     os.makedirs(parent_dir, exist_ok=True)
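For reference, a quick sketch of what these helpers resolve to (values illustrative; with SGLANG_TORCH_PROFILER_DIR unset the directory falls back to /tmp):

    # Illustrative only: batch size 1, input 256, output 32, prefill stage.
    trace_path = _create_torch_profiler_filename("profile", 1, 256, 32, "prefill")
    # -> /tmp/profile_batch1_input256_output32_prefill.trace.json.gz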
@@ -413,7 +507,10 @@ def latency_test_run_once(
     log_decode_step,
     profile,
     profile_record_shapes,
+    profiler_activities,
     profile_filename_prefix,
+    profile_stage,
+    tp_rank,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -422,7 +519,6 @@
         )
         return

-    # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool_allocator.clear()

@@ -436,20 +532,33 @@
     tot_latency = 0

     profiler = None
-    if profile:
-        profiler = torch.profiler.profile(
-            activities=profile_activities,
-            with_stack=True,
-            record_shapes=profile_record_shapes,
+    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
+    if enable_profile_prefill:
+        profiler = start_profile(
+            profiler_activities,
+            profile_record_shapes=profile_record_shapes,
+            rank_print=rank_print,
         )
-        profiler.start()

-    # Prefill
     synchronize(device)
     tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
     prefill_latency = time.perf_counter() - tic
+
+    if enable_profile_prefill:
+        trace_filename = _create_torch_profiler_filename(
+            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
+        )
+        stop_profile(
+            profiler,
+            profiler_activities,
+            rank_print=rank_print,
+            save_trace=True,
+            trace_filename=trace_filename,
+            stage="prefill",
+        )
+
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -458,29 +567,37 @@
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput

-    if profile:
-        profiler.stop()
-        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, trace_filename)
-        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
-
-    # Decode
     decode_latencies = []
+    profile_step_of_interest = output_len // 2
+    enable_profile_decode = profile and profile_stage in ["all", "decode"]
     for i in range(output_len - 1):
         synchronize(device)
-        if profile and i == output_len / 2:
-            profiler = None
-            profiler = torch.profiler.profile(
-                activities=profile_activities,
-                with_stack=True,
-                record_shapes=profile_record_shapes,
+        profiler = None
+        if enable_profile_decode and i == profile_step_of_interest:
+            profiler = start_profile(
+                profiler_activities,
+                profile_record_shapes=profile_record_shapes,
+                rank_print=rank_print,
             )
-            profiler.start()

         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
         latency = time.perf_counter() - tic
+
+        if enable_profile_decode and i == profile_step_of_interest:
+            trace_filename = _create_torch_profiler_filename(
+                profile_filename_prefix, batch_size, input_len, output_len, "decode"
+            )
+            stop_profile(
+                profiler,
+                profiler_activities,
+                rank_print=rank_print,
+                save_trace=True,
+                trace_filename=trace_filename,
+                stage="decode",
+            )
+
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
@@ -489,14 +606,6 @@
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )

-        if profile and i == output_len / 2:
-            profiler.stop()
-            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, trace_filename)
-            rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
-            )
-
     # Record decode timing from 2nd output
     if output_len > 1:
         med_decode_latency = np.median(decode_latencies)
@@ -557,7 +666,10 @@
         log_decode_step=0,
         profile=False,
         profile_record_shapes=False,
-        profile_filename_prefix="",  # not used
+        profiler_activities=("CPU", "GPU"),
+        profile_filename_prefix="",
+        profile_stage="all",
+        tp_rank=tp_rank,
     )

     rank_print("Benchmark ...")
@@ -604,7 +716,10 @@
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_record_shapes if tp_rank == 0 else None,
+            bench_args.profiler_activities,
             bench_args.profile_filename_prefix,
+            bench_args.profile_stage,
+            tp_rank,
         )
         if ret is not None:
             result_list.append(ret)
sglang/bench_serving.py CHANGED
@@ -88,6 +88,7 @@ class RequestFuncOutput:
     latency: float = 0.0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+    text_chunks: List[str] = field(default_factory=list)
     prompt_len: int = 0
     error: str = ""
     output_len: int = 0
@@ -258,6 +259,9 @@ async def async_request_openai_completions(

                                 # Decoding phase
                                 else:
+                                    output.text_chunks.append(
+                                        data["choices"][0]["text"]
+                                    )
                                     output.itl.append(timestamp - most_recent_timestamp)

                                 most_recent_timestamp = timestamp
@@ -574,9 +578,8 @@ async def async_request_sglang_generate(
                                     num_new_tokens = output_len - last_output_len
                                     if num_new_tokens == 0:
                                         continue
-                                    adjust_itl = (
-                                        timestamp - most_recent_timestamp
-                                    ) / num_new_tokens
+                                    chunk_gap = timestamp - most_recent_timestamp
+                                    adjust_itl = chunk_gap / num_new_tokens
                                     output.itl.extend([adjust_itl] * num_new_tokens)

                                 most_recent_timestamp = timestamp
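The renaming above does not change the arithmetic. As a worked sketch (numbers illustrative): a streamed chunk arriving 30 ms after the previous one and carrying 3 new tokens contributes three equal inter-token latencies.

    # Worked example of the per-token adjustment (values illustrative).
    chunk_gap = 0.030          # seconds since the previous chunk
    num_new_tokens = 3         # tokens reported for this chunk
    adjust_itl = chunk_gap / num_new_tokens
    itl_contribution = [adjust_itl] * num_new_tokens  # [0.01, 0.01, 0.01]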
@@ -764,6 +767,7 @@ def get_dataset(args, tokenizer, model_id=None):
             image_content=args.image_content,
             image_format=args.image_format,
             image_resolution=args.image_resolution,
+            backend=args.backend,
         )
     elif args.dataset_name == "generated-shared-prefix":
         assert not tokenize_prompt
@@ -781,6 +785,7 @@ def get_dataset(args, tokenizer, model_id=None):
         input_requests = sample_mmmu_requests(
             num_requests=args.num_prompts,
             processor=processor,
+            backend=args.backend,
             fixed_output_len=args.random_output_len,
             random_sample=True,
         )
@@ -1009,6 +1014,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
+    backend: str = "sglang",
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
@@ -1081,7 +1087,7 @@ def sample_mmmu_requests(
             text_prompt = f"Question: {question}\n\nAnswer: "
             output_len = fixed_output_len if fixed_output_len is not None else 256
             data_row = create_mm_data_row(
-                text_prompt, [image], [image_data], output_len, processor
+                text_prompt, [image], [image_data], output_len, processor, backend
             )
             filtered_dataset.append(data_row)

@@ -1316,13 +1322,19 @@ def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
     )


-def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
+def create_mm_data_row(
+    text_prompt, images: list, images_base64, output_len, processor, backend
+):
     try:
-        content_items = [
-            {"type": "image", "image": {"url": image_base64}}
-            for image_base64 in images_base64
-        ]
-        content_items.append({"type": "text", "text": text_prompt})
+        if type(processor).__name__ == "Phi4MMProcessor":
+            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+            content_items = text_prompt.replace("image 1", "|endoftext10|")
+        else:
+            content_items = [
+                {"type": "image", "image": {"url": image_base64}}
+                for image_base64 in images_base64
+            ]
+            content_items.append({"type": "text", "text": text_prompt})
         prompt_str = processor.apply_chat_template(
             [{"role": "user", "content": content_items}],
             add_generation_prompt=True,
@@ -1357,13 +1369,24 @@ def create_mm_data_row(text_prompt, images: list, images_base64, output_len, pro
         )["input_ids"].numel()
     except Exception:
         # Fallback: just tokenize the text prompt directly
-        text_prompt_len = len(processor.tokenizer.encode(text_prompt))
+        tokenizer_to_use = (
+            processor.tokenizer if hasattr(processor, "tokenizer") else processor
+        )
+        text_prompt_len = len(tokenizer_to_use.encode(text_prompt))

     # Vision tokens = total tokens - text tokens
     vision_prompt_len = prompt_len - text_prompt_len

+    use_raw_prompt = backend in [
+        "sglang-oai",
+        "sglang-oai-chat",
+        "vllm",
+        "vllm-chat",
+        "lmdeploy",
+        "lmdeploy-chat",
+    ]
     return DatasetRow(
-        prompt=text_prompt,
+        prompt=text_prompt if use_raw_prompt else prompt_str,
         prompt_len=prompt_len,
         output_len=output_len,
         text_prompt_len=text_prompt_len,
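A worked sketch of the accounting above (numbers illustrative): the vision token count is simply whatever the fully templated prompt adds on top of the bare text prompt, and the OpenAI-compatible chat backends keep the raw text prompt so the server can apply its own chat template.

    # Illustrative numbers only.
    prompt_len = 842        # tokens in the templated multimodal prompt
    text_prompt_len = 58    # tokens in the bare text prompt
    vision_prompt_len = prompt_len - text_prompt_len  # 784 tokens attributed to images

    backend = "sglang-oai-chat"  # one of the backends listed in the diff above
    use_raw_prompt = backend in [
        "sglang-oai", "sglang-oai-chat", "vllm", "vllm-chat", "lmdeploy", "lmdeploy-chat"
    ]
    # DatasetRow.prompt becomes text_prompt here; other backends get the templated prompt_str.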
@@ -1382,6 +1405,7 @@ def sample_image_requests(
     image_content: str,
     image_format: str,
     image_resolution: str,
+    backend: str,
 ) -> List[DatasetRow]:
     """Generate requests with images.

@@ -1447,6 +1471,7 @@ def sample_image_requests(
             list(images_base64),
             int(output_lens[i]),
             processor,
+            backend,
         )

         dataset.append(data_row)
@@ -1607,6 +1632,7 @@ def calculate_metrics(
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
     backend: str,
+    accept_length: Optional[float] = None,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     output_lens: List[int] = []
     retokenized_output_lens: List[int] = []
@@ -1618,6 +1644,14 @@ def calculate_metrics(
     tpots: List[float] = []
     ttfts: List[float] = []
     e2e_latencies: List[float] = []
+    retokenized_itls: List[float] = []
+
+    use_retokenized_itl = (
+        accept_length is not None
+        and accept_length > 0
+        and backend in ("sglang-oai", "sglang-oai-chat")
+    )
+
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_len
@@ -1631,7 +1665,17 @@ def calculate_metrics(
             total_input_vision += input_requests[i].vision_prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
-                itls += outputs[i].itl
+                if use_retokenized_itl:
+                    for k, itl in enumerate(outputs[i].itl):
+                        num_tokens = len(
+                            tokenizer.encode(
+                                outputs[i].text_chunks[k], add_special_tokens=False
+                            )
+                        )
+                        adjusted_itl = itl / num_tokens
+                        retokenized_itls.extend([adjusted_itl] * num_tokens)
+                else:
+                    itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)

             e2e_latencies.append(outputs[i].latency)
@@ -1647,6 +1691,8 @@ def calculate_metrics(
             "on the benchmark arguments.",
             stacklevel=2,
         )
+
+    itls = retokenized_itls if use_retokenized_itl else itls
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
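As a worked sketch of the retokenized-ITL path above (values illustrative): when speculative decoding reports an accept length, a single streamed chunk can carry several tokens, so each chunk's latency is split across the tokens obtained by re-encoding the chunk text.

    # Illustrative only: one 24 ms chunk whose text re-encodes to 4 tokens.
    itl = 0.024
    num_tokens = 4            # len(tokenizer.encode(chunk_text, add_special_tokens=False))
    adjusted_itl = itl / num_tokens
    retokenized_itls = [adjusted_itl] * num_tokens  # four 6 ms entries instead of one 24 ms entry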
@@ -1910,6 +1956,7 @@ async def benchmark(
         dur_s=benchmark_duration,
         tokenizer=tokenizer,
         backend=backend,
+        accept_length=accept_length,
     )

     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
@@ -1989,6 +2036,7 @@ async def benchmark(
     ):
         result = {
             # Arguments
+            "tag": getattr(args, "tag", None),
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": "trace" if use_trace_timestamps else request_rate,
@@ -2114,6 +2162,9 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "mooncake_num_rounds"):
         args.mooncake_num_rounds = 1

+    if not hasattr(args, "served_model_name"):
+        args.served_model_name = None
+
     print(f"benchmark_args={args}")

     # Set global environments
@@ -2227,7 +2278,7 @@ def run_benchmark(args_: argparse.Namespace):

     # Read dataset
     backend = args.backend
-    model_id = args.model
+    model_id = args.served_model_name or args.model
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer = get_tokenizer(tokenizer_id)
     input_requests = get_dataset(args, tokenizer, model_id)
@@ -2326,6 +2377,11 @@ if __name__ == "__main__":
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
+    )
     parser.add_argument(
         "--tokenizer",
         type=str,
@@ -2583,5 +2639,8 @@
         ],
         help="Underlying workload for the mooncake dataset.",
     )
+    parser.add_argument(
+        "--tag", type=str, default=None, help="The tag to be dumped to output."
+    )
     args = parser.parse_args()
     run_benchmark(args)
sglang/compile_deep_gemm.py CHANGED
@@ -104,15 +104,21 @@ def launch_server_process_and_send_one_request(
        if response.status_code == 200:
            # Rank-0 node send a request to sync with other node and then return.
            if server_args.node_rank == 0:
+               payload = {
+                   "input_ids": [0, 1, 2, 3],
+                   "sampling_params": {
+                       "max_new_tokens": 8,
+                       "temperature": 0,
+                   },
+               }
+               # In PD mode, include fake bootstrap fields so workers don't assert
+               if server_args.disaggregation_mode != "null":
+                   payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
+                   payload["bootstrap_room"] = 0
+
                response = requests.post(
                    f"{base_url}/generate",
-                   json={
-                       "input_ids": [0, 1, 2, 3],
-                       "sampling_params": {
-                           "max_new_tokens": 8,
-                           "temperature": 0,
-                       },
-                   },
+                   json=payload,
                    timeout=600,
                )
                if response.status_code != 200:
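A condensed, hypothetical standalone version of the warmup call assembled above; in the real script the mode comes from server_args and the bootstrap host from sglang's FAKE_BOOTSTRAP_HOST constant, and the URL here is illustrative.

    import requests

    base_url = "http://127.0.0.1:30000"   # illustrative
    disaggregation_mode = "prefill"       # anything other than "null" means PD mode
    fake_bootstrap_host = "0.0.0.0"       # stand-in for FAKE_BOOTSTRAP_HOST

    payload = {
        "input_ids": [0, 1, 2, 3],
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    }
    if disaggregation_mode != "null":
        payload["bootstrap_host"] = fake_bootstrap_host
        payload["bootstrap_room"] = 0

    response = requests.post(f"{base_url}/generate", json=payload, timeout=600)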
sglang/launch_server.py CHANGED
@@ -12,10 +12,12 @@ if __name__ == "__main__":

    try:
        if server_args.grpc_mode:
+           # Handle gRPC server
            from sglang.srt.entrypoints.grpc_server import serve_grpc

            asyncio.run(serve_grpc(server_args))
        else:
+           # Handle HTTP server
            from sglang.srt.entrypoints.http_server import launch_server

            launch_server(server_args)
sglang/srt/batch_invariant_ops/__init__.py CHANGED
@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
     log_softmax,
     matmul_persistent,
     mean_dim,
+    rms_norm_batch_invariant,
     set_batch_invariant_mode,
 )

@@ -24,4 +25,5 @@
     "mean_dim",
     "get_batch_invariant_attention_block_size",
     "AttentionBlockSize",
+    "rms_norm_batch_invariant",
 ]
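With the re-export above, the batch-invariant RMSNorm kernel is importable alongside the existing helpers. A sketch of the import only, since the call signature is not part of this diff:

    from sglang.srt.batch_invariant_ops import (
        rms_norm_batch_invariant,
        set_batch_invariant_mode,
    )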