sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch_server.py CHANGED
@@ -47,6 +47,7 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
+    dataset_path: str = ""
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -83,6 +84,12 @@ class BenchArgs:
             "--profile-steps", type=int, default=BenchArgs.profile_steps
         )
         parser.add_argument("--profile-by-stage", action="store_true")
+        parser.add_argument(
+            "--dataset-path",
+            type=str,
+            default=BenchArgs.dataset_path,
+            help="Path to the dataset.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -138,6 +145,7 @@ def run_one_case(
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    dataset_path: str = "",
 ):
     requests.post(url + "/flush_cache")
     input_requests = sample_random_requests(
@@ -146,7 +154,7 @@
         num_prompts=batch_size,
         range_ratio=1.0,
         tokenizer=tokenizer,
-        dataset_path="",
+        dataset_path=dataset_path,
         random_sample=True,
         return_text=False,
     )
@@ -345,6 +353,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         run_name="",
         result_filename="",
         tokenizer=tokenizer,
+        dataset_path=bench_args.dataset_path,
     )
     print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
sglang/bench_serving.py CHANGED
@@ -75,6 +75,7 @@ class RequestFuncInput:
     lora_name: str
     image_data: Optional[List[str]]
     extra_request_body: Dict[str, Any]
+    timestamp: Optional[float] = None
 
 
 @dataclass
@@ -104,10 +105,13 @@ def remove_suffix(text: str, suffix: str) -> str:
 
 
 def get_auth_headers() -> Dict[str, str]:
-    api_key = os.environ.get("OPENAI_API_KEY")
-    if api_key:
-        return {"Authorization": f"Bearer {api_key}"}
+    openai_api_key = os.environ.get("OPENAI_API_KEY")
+    if openai_api_key:
+        return {"Authorization": f"Bearer {openai_api_key}"}
     else:
+        api_key = os.environ.get("API_KEY")
+        if api_key:
+            return {"Authorization": f"{api_key}"}
         return {}
 
 
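Note on the change above: OPENAI_API_KEY still takes precedence and is sent as a Bearer token, while a plain API_KEY value is now forwarded verbatim as the Authorization header. A minimal standalone sketch of that precedence (it mirrors the logic above rather than importing it):

import os


def get_auth_headers_sketch() -> dict:
    # OPENAI_API_KEY wins and is wrapped as a Bearer token.
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if openai_api_key:
        return {"Authorization": f"Bearer {openai_api_key}"}
    # Otherwise a raw API_KEY, if set, is passed through unchanged.
    api_key = os.environ.get("API_KEY")
    if api_key:
        return {"Authorization": api_key}
    return {}


os.environ.pop("OPENAI_API_KEY", None)
os.environ["API_KEY"] = "Basic dXNlcjpwYXNz"  # placeholder credential
assert get_auth_headers_sketch() == {"Authorization": "Basic dXNlcjpwYXNz"}
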
@@ -696,6 +700,24 @@ def get_dataset(args, tokenizer):
             apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
+    elif args.dataset_name == "mooncake":
+        # For mooncake, we don't generate the prompts here.
+        # We just load the raw trace data. The async generator will handle the rest.
+        if not args.dataset_path:
+            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
+        else:
+            local_path = args.dataset_path
+
+        if not os.path.exists(local_path):
+            download_and_cache_file(
+                MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
+            )
+
+        with open(local_path, "r") as f:
+            all_requests_data = [json.loads(line) for line in f if line.strip()]
+
+        # Limit the number of requests based on --num-prompts
+        input_requests = all_requests_data[: args.num_prompts]
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests
@@ -750,6 +772,12 @@ class BenchmarkMetrics:
 
 
 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+MOONCAKE_DATASET_URL = {
+    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
+    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
+    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
+    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
+}
 
 
 def download_and_cache_file(url: str, filename: Optional[str] = None):
@@ -808,6 +836,80 @@ class DatasetRow:
     prompt_len: int
     output_len: int
     image_data: Optional[List[str]] = None
+    timestamp: Optional[float] = None
+
+
+async def get_mooncake_request_over_time(
+    input_requests: List[Dict],
+    tokenizer: PreTrainedTokenizerBase,
+    slowdown_factor: float,
+    num_rounds: int,
+) -> AsyncGenerator[DatasetRow, None]:
+    """
+    An async generator that yields requests based on the timestamps in the Mooncake trace file,
+    with support for multi-round sessions.
+    """
+    if not input_requests:
+        return
+
+    input_requests.sort(key=lambda r: r["timestamp"])
+
+    start_time = time.perf_counter()
+    trace_start_time_ms = input_requests[0]["timestamp"]
+
+    for record in input_requests:
+        # Calculate when this entire session should start
+        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
+        target_arrival_time_s = relative_arrival_time_s * slowdown_factor
+
+        current_elapsed_time_s = time.perf_counter() - start_time
+        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
+        if sleep_duration_s > 0:
+            await asyncio.sleep(sleep_duration_s)
+
+        # Once the session starts, generate all rounds for it as a burst
+        # This simulates a user engaging in a multi-turn conversation
+
+        # Base user query constructed from hash_ids
+        user_query_base = ""
+        hash_ids = record.get("hash_ids", [])
+        for hash_id in hash_ids:
+            user_query_base += f"{hash_id}" + " ".join(
+                ["hi"] * 128
+            )  # Shorter for multi-round
+        user_query_base += "Tell me a story based on this context."
+
+        output_len_per_round = record.get("output_length", 256)
+        chat_history = []
+
+        for i in range(num_rounds):
+            # Add user query for the current round
+            chat_history.append(
+                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
+            )
+
+            # Form the full prompt from history
+            try:
+                full_prompt_text = tokenizer.apply_chat_template(
+                    chat_history, tokenize=False, add_generation_prompt=True
+                )
+            except Exception:
+                full_prompt_text = "\n".join(
+                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
+                )
+
+            prompt_len = len(tokenizer.encode(full_prompt_text))
+
+            yield DatasetRow(
+                prompt=full_prompt_text,
+                prompt_len=prompt_len,
+                output_len=output_len_per_round,
+            )
+
+            # Add a placeholder assistant response for the next round's context
+            # We use a placeholder because we don't know the real response
+            placeholder_response = " ".join(["story"] * output_len_per_round)
+            chat_history.append({"role": "assistant", "content": placeholder_response})
 
 
 def sample_mmmu_requests(
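
The generator above paces whole sessions by their trace timestamps: each record's millisecond timestamp is rebased against the first record, scaled by slowdown_factor, and converted into a sleep before the session's rounds are issued as a burst. A self-contained sketch of just that scheduling arithmetic, using hypothetical trace records in the same shape:

import asyncio
import time

# Hypothetical records in the Mooncake trace format used above: a millisecond
# timestamp, hash_ids describing the cached prefix, and a per-round output length.
TRACE = [
    {"timestamp": 0, "hash_ids": [1, 2], "output_length": 64},
    {"timestamp": 1500, "hash_ids": [1, 3], "output_length": 128},
]


async def replay(trace, slowdown_factor=1.0):
    start = time.perf_counter()
    t0_ms = trace[0]["timestamp"]
    for record in trace:
        # Target arrival time relative to benchmark start, scaled by the slowdown factor.
        target_s = (record["timestamp"] - t0_ms) / 1000.0 * slowdown_factor
        delay = target_s - (time.perf_counter() - start)
        if delay > 0:
            await asyncio.sleep(delay)
        print(f"dispatch at +{time.perf_counter() - start:.2f}s -> {record}")


asyncio.run(replay(TRACE, slowdown_factor=2.0))
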
896
998
  prompt = f"Question: {question}\n\nAnswer: "
897
999
  if apply_chat_template:
898
1000
  try:
1001
+ is_phi4_multimodal = (
1002
+ "phi-4-multimodal" in tokenizer.name_or_path.lower()
1003
+ )
1004
+ if is_phi4_multimodal:
1005
+ # <|endoftext10|> is the image token used in the phi-4-multimodal model.
1006
+ content = prompt.replace("image 1", "<|endoftext10|>")
1007
+ else:
1008
+ content = [
1009
+ {
1010
+ "type": "image_url",
1011
+ "image_url": {"url": image_data},
1012
+ },
1013
+ {"type": "text", "text": prompt},
1014
+ ]
899
1015
  prompt = tokenizer.apply_chat_template(
900
1016
  [
901
1017
  {
902
1018
  "role": "user",
903
- "content": [
904
- {
905
- "type": "image_url",
906
- "image_url": {"url": image_data},
907
- },
908
- {"type": "text", "text": prompt},
909
- ],
1019
+ "content": content,
910
1020
  }
911
1021
  ],
912
1022
  add_generation_prompt=True,
@@ -1359,19 +1469,41 @@
 async def get_request(
     input_requests: List[DatasetRow],
     request_rate: float,
+    use_trace_timestamps: bool = False,
+    slowdown_factor: float = 1.0,
 ) -> AsyncGenerator[DatasetRow, None]:
-    input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request
+    if use_trace_timestamps:
+        print(
+            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
+        )
+        # Sort requests by timestamp for correct replay
+        input_requests.sort(key=lambda r: r.timestamp)
 
-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
+        start_time = time.perf_counter()
+        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
+
+        for request in input_requests:
+            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
+            target_arrival_time = start_time + (trace_time_s * slowdown_factor)
+
+            sleep_duration = target_arrival_time - time.perf_counter()
+            if sleep_duration > 0:
+                await asyncio.sleep(sleep_duration)
+
+            yield request
+    else:
+        input_requests_iter = iter(input_requests)
+        for request in input_requests_iter:
+            yield request
 
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)
 
 
 def calculate_metrics(
@@ -1397,7 +1529,7 @@
             tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
-        total_input += input_requests[i].prompt_len
+        total_input += outputs[i].prompt_len
         if output_len > 1:
             tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
         itls += outputs[i].itl
@@ -1469,6 +1601,9 @@
     pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
+    use_trace_timestamps: bool = False,
+    mooncake_slowdown_factor=1.0,
+    mooncake_num_rounds=1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1488,8 +1623,32 @@
     # Warmup
     print(f"Starting warmup with {warmup_requests} sequences...")
 
-    # Use the first request for all warmup iterations
-    test_request = input_requests[0]
+    # Handle the data structure difference for the warmup request
+    if args.dataset_name == "mooncake":
+        # For mooncake, input_requests is a list of dicts.
+        # We need to build a temporary DatasetRow for the warmup phase.
+        warmup_record = input_requests[0]
+
+        # Build prompt from hash_ids, just like in the async generator
+        hash_ids = warmup_record.get("hash_ids", [])
+        prompt_text = ""
+        for hash_id in hash_ids:
+            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
+        prompt_text += "Can you tell me a detailed story in 1000 words?"
+
+        output_len = warmup_record.get("output_length", 32)
+        prompt_len = len(tokenizer.encode(prompt_text))
+
+        # Create a temporary DatasetRow object for warmup
+        test_request = DatasetRow(
+            prompt=prompt_text,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            image_data=None,  # Mooncake doesn't have image data
+        )
+    else:
+        # For all other datasets, input_requests is a list of DatasetRow objects
+        test_request = input_requests[0]
 
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
@@ -1543,12 +1702,26 @@
         if profile_output.success:
             print("Profiler started")
 
-    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
-
     # Run all requests
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
+    pbar_total = len(input_requests)
+    if (
+        backend == "sglang" and args.dataset_name == "mooncake"
+    ):  # Assuming mooncake is mainly for sglang or similar backends
+        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
+        request_generator = get_mooncake_request_over_time(
+            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
+        )
+        print(
+            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
+        )
+        pbar_total *= args.mooncake_num_rounds
+    else:
+        request_generator = get_request(input_requests, request_rate)
+
+    pbar = None if disable_tqdm else tqdm(total=pbar_total)
+    async for request in request_generator:
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1564,6 +1737,7 @@
             lora_name=lora_name,
             image_data=request.image_data,
             extra_request_body=extra_request_body,
+            timestamp=request.timestamp,
         )
 
         tasks.append(
@@ -1609,7 +1783,11 @@
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
-    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
+        )
+    )
     print(
         "{:<40} {:<10}".format(
             "Max request concurrency:",
@@ -1678,7 +1856,7 @@
         # Arguments
         "backend": args.backend,
         "dataset_name": args.dataset_name,
-        "request_rate": request_rate,
+        "request_rate": "trace" if use_trace_timestamps else request_rate,
         "max_concurrency": max_concurrency,
         "sharegpt_output_len": args.sharegpt_output_len,
         "random_input_len": args.random_input_len,
@@ -1731,7 +1909,9 @@
     elif args.dataset_name.startswith("random"):
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
     else:
-        output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
+        output_file_name = (
+            f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
+        )
 
     result_details = {
         "input_lens": [output.prompt_len for output in outputs],
@@ -1786,6 +1966,17 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "tokenize_prompt"):
        args.tokenize_prompt = False
 
+    if not hasattr(args, "use_trace_timestamps"):
+        args.use_trace_timestamps = False
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_num_rounds"):
+        args.mooncake_num_rounds = 1
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -1919,6 +2110,9 @@
             pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
             warmup_requests=args.warmup_requests,
+            use_trace_timestamps=args.use_trace_timestamps,
+            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
+            mooncake_num_rounds=args.mooncake_num_rounds,
         )
     )
 
@@ -1975,6 +2169,7 @@
             "generated-shared-prefix",
             "mmmu",
             "random-image",
+            "mooncake",
         ],
         help="Name of the dataset to benchmark on.",
     )
@@ -2051,6 +2246,11 @@
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--use-trace-timestamps",
+        action="store_true",
+        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
+    )
     parser.add_argument(
         "--max-concurrency",
         type=int,
@@ -2174,5 +2374,33 @@
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
+    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
+    mooncake_group.add_argument(
+        "--mooncake-slowdown-factor",
+        type=float,
+        default=1.0,
+        help="Slowdown factor for replaying the mooncake trace. "
+        "A value of 2.0 means the replay is twice as slow. "
+        "NOTE: --request-rate is IGNORED in mooncake mode.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-num-rounds",
+        type=int,
+        default=1,
+        help="Number of conversation rounds for each session in the mooncake dataset. "
+        "A value > 1 will enable true multi-turn session benchmarking.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-workload",
+        type=str,
+        default="conversation",
+        choices=[
+            "mooncake",
+            "conversation",
+            "synthetic",
+            "toolagent",
+        ],
+        help="Underlying workload for the mooncake dataset.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
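
Taken together, a Mooncake replay run combines the new flags registered above. A hedged sketch of an equivalent invocation (host, port, and request count are placeholders; the flags themselves are the ones added in this diff):

# Build the command as a list so it can be handed to subprocess.run if desired.
example_cmd = [
    "python3", "-m", "sglang.bench_serving",
    "--backend", "sglang",
    "--host", "127.0.0.1", "--port", "30000",
    "--dataset-name", "mooncake",
    "--mooncake-workload", "conversation",
    "--mooncake-num-rounds", "2",
    "--mooncake-slowdown-factor", "1.0",
    "--use-trace-timestamps",
    "--num-prompts", "100",
]
print(" ".join(example_cmd))
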
sglang/srt/configs/__init__.py CHANGED
@@ -1,11 +1,13 @@
 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 from sglang.srt.configs.longcat_flash import LongcatFlashConfig
+from sglang.srt.configs.qwen3_next import Qwen3NextConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -24,4 +26,6 @@ __all__ = [
     "Step3VLConfig",
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
+    "Qwen3NextConfig",
+    "DotsVLMConfig",
 ]
sglang/srt/configs/device_config.py CHANGED
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
 
 class DeviceConfig:
     device: Optional[torch.device]
+    gpu_id: Optional[int]
 
-    def __init__(self, device: str = "cuda") -> None:
+    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
         if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
             self.device_type = device
         else:
             raise RuntimeError(f"Not supported device type: {device}")
         self.device = torch.device(self.device_type)
+        self.gpu_id = gpu_id
sglang/srt/configs/dots_vlm.py ADDED
@@ -0,0 +1,139 @@
+from typing import Any, List, Optional, Union
+
+from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessingKwargs, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+try:
+    from transformers import Qwen2_5_VLProcessor
+except ImportError:
+    raise ImportError(
+        "Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version."
+    )
+
+from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
+
+
+class DotsVisionConfig(PretrainedConfig):
+    model_type: str = "dots_vit"
+
+    def __init__(
+        self,
+        embed_dim: int = 1536,  # vision encoder embed size
+        hidden_size: int = 1536,  # after merger hidden size
+        intermediate_size: int = 4224,
+        num_hidden_layers: int = 42,
+        num_attention_heads: int = 12,
+        num_channels: int = 3,
+        patch_size: int = 14,
+        spatial_merge_size: int = 2,
+        temporal_patch_size: int = 1,
+        rms_norm_eps: float = 1e-5,
+        use_bias: bool = False,
+        attn_implementation="flash_attention_2",  # "eager","sdpa","flash_attention_2"
+        initializer_range=0.02,
+        init_merger_std=0.02,
+        is_causal=False,  # ve causal forward
+        post_norm=True,
+        gradient_checkpointing=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.rms_norm_eps = rms_norm_eps
+        self.use_bias = use_bias
+        self.attn_implementation = attn_implementation
+        self.initializer_range = initializer_range
+        self.init_merger_std = init_merger_std
+        self.is_causal = is_causal
+        self.post_norm = post_norm
+        self.gradient_checkpointing = gradient_checkpointing
+
+
+class DotsVLMConfig(PretrainedConfig):
+    model_type = "dots_vlm"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        vision_config = kwargs.get("vision_config", {})
+        self.im_span_id = kwargs.get("image_token_id", 128815)
+        self.video_span_id = kwargs.get("video_token_id", 128836)
+        self.vision_config = DotsVisionConfig(**vision_config)
+        self.language_config = DeepseekV2Config(**kwargs)
+        self.architectures = ["DotsVLMForCausalLM"]
+
+
+class DotsVLMProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+    }
+
+
+class DotsVLMProcessor(Qwen2_5_VLProcessor):
+    r"""
+    Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video token ids.
+    Besides, its tokenizer is a LlamaTokenizerFast instead of Qwen2TokenizerFast.
+    [`DotsVLMProcessor`] offers all the functionalities of [`DotsVisionConfig`] and [`LlamaTokenizerFast`]. See the
+    [`~DotsVLMProcessor.__call__`] and [`~DotsVLMProcessor.decode`] for more information.
+    Args:
+        image_processor ([`Qwen2VLImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+
+    valid_kwargs = ["chat_template"]
+
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    def __init__(
+        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
+    ):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+        self.image_token = (
+            "<|imgpad|>"
+            if not hasattr(tokenizer, "image_token")
+            else tokenizer.image_token
+        )
+        self.video_token = (
+            "<|video_pad|>"
+            if not hasattr(tokenizer, "video_token")
+            else tokenizer.video_token
+        )
+        self.img_token = (
+            "<|img|>" if not hasattr(tokenizer, "img_token") else tokenizer.img_token
+        )
+        self.endofimg_token = (
+            "<|endofimg|>"
+            if not hasattr(tokenizer, "endofimg_token")
+            else tokenizer.endofimg_token
+        )
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if getattr(tokenizer, "image_token_id", None)
+            else tokenizer.encode(self.image_token)[0]
+        )
+        self.video_token_id = (
+            tokenizer.video_token_id
+            if getattr(tokenizer, "video_token_id", None)
+            else tokenizer.encode(self.video_token)[0]
+        )
+
+
+AutoProcessor.register(DotsVLMConfig, DotsVLMProcessor)
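
For orientation, a small sketch of constructing the new vision config directly (assuming sglang 0.5.3rc0 and a transformers version that provides Qwen2_5_VLProcessor are installed; the defaults are the ones declared above):

from sglang.srt.configs.dots_vlm import DotsVisionConfig

# Every value defaults to the numbers defined in the class above.
vision_cfg = DotsVisionConfig(num_hidden_layers=2)  # override just the depth
print(vision_cfg.embed_dim, vision_cfg.patch_size, vision_cfg.num_hidden_layers)
# -> 1536 14 2
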
sglang/srt/configs/load_config.py CHANGED
@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
     LAYERED = "layered"
     JAX = "jax"
     REMOTE = "remote"
+    REMOTE_INSTANCE = "remote_instance"
 
 
 @dataclass