sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
18
18
 
19
19
  import copy
20
20
  import uuid
21
+ from abc import ABC
21
22
  from dataclasses import dataclass, field
22
23
  from enum import Enum
23
24
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -35,10 +36,33 @@ else:
35
36
  Image = Any
36
37
 
37
38
 
39
+ # Parameters for a session
40
+ @dataclass
41
+ class BaseReq(ABC):
42
+ rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
43
+
44
+ def regenerate_rid(self):
45
+ """Generate a new request ID and return it."""
46
+ if isinstance(self.rid, list):
47
+ self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
48
+ else:
49
+ self.rid = uuid.uuid4().hex
50
+ return self.rid
51
+
52
+
53
+ @dataclass
54
+ class BaseBatchReq(ABC):
55
+ rids: Optional[List[str]] = field(default=None, kw_only=True)
56
+
57
+ def regenerate_rids(self):
58
+ """Generate new request IDs and return them."""
59
+ self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
60
+ return self.rids
61
+
62
+
38
63
  @dataclass
39
64
  class SessionParams:
40
65
  id: Optional[str] = None
41
- rid: Optional[str] = None
42
66
  offset: Optional[int] = None
43
67
  replace: Optional[bool] = None
44
68
  drop_previous_output: Optional[bool] = None
@@ -62,7 +86,7 @@ MultimodalDataInputFormat = Union[
62
86
 
63
87
 
64
88
  @dataclass
65
- class GenerateReqInput:
89
+ class GenerateReqInput(BaseReq):
66
90
  # The input prompt. It can be a single prompt or a batch of prompts.
67
91
  text: Optional[Union[List[str], str]] = None
68
92
  # The token ids for text; one can specify either text or input_ids
@@ -82,8 +106,6 @@ class GenerateReqInput:
82
106
  audio_data: Optional[MultimodalDataInputFormat] = None
83
107
  # The sampling_params. See descriptions below.
84
108
  sampling_params: Optional[Union[List[Dict], Dict]] = None
85
- # The request id.
86
- rid: Optional[Union[List[str], str]] = None
87
109
  # Whether to return logprobs.
88
110
  return_logprob: Optional[Union[List[bool], bool]] = None
89
111
  # If return logprobs, the start location in the prompt for returning logprobs.
@@ -121,6 +143,7 @@ class GenerateReqInput:
121
143
  bootstrap_host: Optional[Union[List[str], str]] = None
122
144
  bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
123
145
  bootstrap_room: Optional[Union[List[int], int]] = None
146
+ bootstrap_pair_key: Optional[Union[List[str], str]] = None
124
147
 
125
148
  # For data parallel rank routing
126
149
  data_parallel_rank: Optional[int] = None
@@ -128,6 +151,24 @@ class GenerateReqInput:
128
151
  # For background responses (OpenAI responses API)
129
152
  background: bool = False
130
153
 
154
+ # Conversation id used for tracking requests
155
+ conversation_id: Optional[str] = None
156
+
157
+ # Priority for the request
158
+ priority: Optional[int] = None
159
+
160
+ # Extra key for classifying the request (e.g. cache_salt)
161
+ extra_key: Optional[Union[List[str], str]] = None
162
+
163
+ # Whether to disallow logging for this request (e.g. due to ZDR)
164
+ no_logs: bool = False
165
+
166
+ # For custom metric labels
167
+ custom_labels: Optional[Dict[str, str]] = None
168
+
169
+ # (Internal) Whether to return bytes for image generation
170
+ return_bytes: bool = False
171
+
131
172
  def contains_mm_input(self) -> bool:
132
173
  return (
133
174
  has_valid_data(self.image_data)
@@ -258,6 +299,7 @@ class GenerateReqInput:
258
299
  self._normalize_sampling_params(num)
259
300
  self._normalize_logprob_params(num)
260
301
  self._normalize_custom_logit_processor(num)
302
+ self._normalize_bootstrap_params(num)
261
303
 
262
304
  def _expand_inputs(self, num):
263
305
  """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
@@ -297,6 +339,11 @@ class GenerateReqInput:
297
339
  self.image_data = [[self.image_data]] * num
298
340
  self.modalities = ["image"] * num
299
341
  elif isinstance(self.image_data, list):
342
+ # Handle empty list case - treat as no images
343
+ if len(self.image_data) == 0:
344
+ self.image_data = [None] * num
345
+ return
346
+
300
347
  if len(self.image_data) != self.batch_size:
301
348
  raise ValueError(
302
349
  "The length of image_data should be equal to the batch size."
@@ -421,6 +468,40 @@ class GenerateReqInput:
421
468
  "Cannot use list custom_logit_processor with parallel_sample_num > 1"
422
469
  )
423
470
 
471
+ def _normalize_bootstrap_params(self, num):
472
+ """Normalize bootstrap parameters for batch processing."""
473
+ # Normalize bootstrap_host
474
+ if self.bootstrap_host is None:
475
+ self.bootstrap_host = [None] * num
476
+ elif not isinstance(self.bootstrap_host, list):
477
+ self.bootstrap_host = [self.bootstrap_host] * num
478
+ elif isinstance(self.bootstrap_host, list):
479
+ self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num
480
+
481
+ # Normalize bootstrap_port
482
+ if self.bootstrap_port is None:
483
+ self.bootstrap_port = [None] * num
484
+ elif not isinstance(self.bootstrap_port, list):
485
+ self.bootstrap_port = [self.bootstrap_port] * num
486
+ elif isinstance(self.bootstrap_port, list):
487
+ self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num
488
+
489
+ # Normalize bootstrap_room
490
+ if self.bootstrap_room is None:
491
+ self.bootstrap_room = [None] * num
492
+ elif not isinstance(self.bootstrap_room, list):
493
+ self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
494
+ elif isinstance(self.bootstrap_room, list):
495
+ self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
496
+
497
+ # Normalize bootstrap_pair_key
498
+ if self.bootstrap_pair_key is None:
499
+ self.bootstrap_pair_key = [None] * num
500
+ elif not isinstance(self.bootstrap_pair_key, list):
501
+ self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
502
+ elif isinstance(self.bootstrap_pair_key, list):
503
+ self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
504
+
424
505
  def _validate_session_params(self):
425
506
  """Validate that session parameters are properly formatted."""
426
507
  if self.session_params is not None:
@@ -429,11 +510,6 @@ class GenerateReqInput:
429
510
  ):
430
511
  raise ValueError("Session params must be a dict or a list of dicts.")
431
512
 
432
- def regenerate_rid(self):
433
- """Generate a new request ID and return it."""
434
- self.rid = uuid.uuid4().hex
435
- return self.rid
436
-
437
513
  def __getitem__(self, i):
438
514
  return GenerateReqInput(
439
515
  text=self.text[i] if self.text is not None else None,
@@ -453,7 +529,13 @@ class GenerateReqInput:
453
529
  return_text_in_logprobs=self.return_text_in_logprobs,
454
530
  stream=self.stream,
455
531
  log_metrics=self.log_metrics,
532
+ return_hidden_states=(
533
+ self.return_hidden_states[i]
534
+ if isinstance(self.return_hidden_states, list)
535
+ else self.return_hidden_states
536
+ ),
456
537
  modalities=self.modalities[i] if self.modalities else None,
538
+ session_params=self.session_params,
457
539
  lora_path=self.lora_path[i] if self.lora_path is not None else None,
458
540
  lora_id=self.lora_id[i] if self.lora_id is not None else None,
459
541
  custom_logit_processor=(
@@ -461,11 +543,6 @@ class GenerateReqInput:
461
543
  if self.custom_logit_processor is not None
462
544
  else None
463
545
  ),
464
- return_hidden_states=(
465
- self.return_hidden_states[i]
466
- if isinstance(self.return_hidden_states, list)
467
- else self.return_hidden_states
468
- ),
469
546
  # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
470
547
  bootstrap_host=(
471
548
  self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -476,16 +553,25 @@ class GenerateReqInput:
476
553
  bootstrap_room=(
477
554
  self.bootstrap_room[i] if self.bootstrap_room is not None else None
478
555
  ),
556
+ bootstrap_pair_key=(
557
+ self.bootstrap_pair_key[i]
558
+ if self.bootstrap_pair_key is not None
559
+ else None
560
+ ),
479
561
  data_parallel_rank=(
480
562
  self.data_parallel_rank if self.data_parallel_rank is not None else None
481
563
  ),
564
+ conversation_id=self.conversation_id,
565
+ priority=self.priority,
566
+ extra_key=self.extra_key,
567
+ no_logs=self.no_logs,
568
+ custom_labels=self.custom_labels,
569
+ return_bytes=self.return_bytes,
482
570
  )
483
571
 
484
572
 
485
573
  @dataclass
486
- class TokenizedGenerateReqInput:
487
- # The request id
488
- rid: str
574
+ class TokenizedGenerateReqInput(BaseReq):
489
575
  # The input text
490
576
  input_text: str
491
577
  # The input token ids
@@ -505,36 +591,50 @@ class TokenizedGenerateReqInput:
505
591
  # Whether to stream output
506
592
  stream: bool
507
593
 
508
- # LoRA related
509
- lora_id: Optional[str] = None # None means just use the base model
594
+ # Whether to return hidden states
595
+ return_hidden_states: bool = False
596
+
510
597
  # The input embeds
511
598
  input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
512
599
 
513
600
  # Session info for continual prompting
514
601
  session_params: Optional[SessionParams] = None
515
602
 
603
+ # LoRA related
604
+ lora_id: Optional[str] = None # None means just use the base model
605
+
516
606
  # Custom logit processor for advanced sampling control. Must be a serialized instance
517
607
  # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
518
608
  # Use the processor's `to_str()` method to generate the serialized string.
519
609
  custom_logit_processor: Optional[str] = None
520
610
 
521
- # Whether to return hidden states
522
- return_hidden_states: bool = False
523
-
524
611
  # For disaggregated inference
525
612
  bootstrap_host: Optional[str] = None
526
613
  bootstrap_port: Optional[int] = None
527
614
  bootstrap_room: Optional[int] = None
615
+ bootstrap_pair_key: Optional[str] = None
528
616
 
529
617
  # For data parallel rank routing
530
618
  data_parallel_rank: Optional[int] = None
531
619
 
532
- # For dp balance
533
- dp_balance_id: int = -1
620
+ # Priority for the request
621
+ priority: Optional[int] = None
622
+
623
+ # Extra key for classifying the request (e.g. cache_salt)
624
+ extra_key: Optional[str] = None
625
+
626
+ # Whether to disallow logging for this request (e.g. due to ZDR)
627
+ no_logs: bool = False
628
+
629
+ # tracing context
630
+ trace_context: Optional[Dict] = None
631
+
632
+ # (Internal) Whether to return bytes for image generation
633
+ return_bytes: bool = False
534
634
 
535
635
 
536
636
  @dataclass
537
- class BatchTokenizedGenerateReqInput:
637
+ class BatchTokenizedGenerateReqInput(BaseBatchReq):
538
638
  # The batch of tokenized requests
539
639
  batch: List[TokenizedGenerateReqInput]
540
640
 
@@ -549,7 +649,7 @@ class BatchTokenizedGenerateReqInput:
549
649
 
550
650
 
551
651
  @dataclass
552
- class EmbeddingReqInput:
652
+ class EmbeddingReqInput(BaseReq):
553
653
  # The input prompt. It can be a single prompt or a batch of prompts.
554
654
  text: Optional[Union[List[List[str]], List[str], str]] = None
555
655
  # The image input. It can be an image instance, file name, URL, or base64 encoded string.
@@ -565,8 +665,6 @@ class EmbeddingReqInput:
565
665
  audio_data: Optional[MultimodalDataInputFormat] = None
566
666
  # The token ids for text; one can either specify text or input_ids.
567
667
  input_ids: Optional[Union[List[List[int]], List[int]]] = None
568
- # The request id.
569
- rid: Optional[Union[List[str], str]] = None
570
668
  # Dummy sampling params for compatibility
571
669
  sampling_params: Optional[Union[List[Dict], Dict]] = None
572
670
  # Dummy input embeds for compatibility
@@ -577,10 +675,15 @@ class EmbeddingReqInput:
577
675
  modalities: Optional[List[str]] = None
578
676
  # For cross-encoder requests
579
677
  is_cross_encoder_request: bool = False
678
+ # Priority for the request
679
+ priority: Optional[int] = None
580
680
 
581
681
  # For background responses (OpenAI responses API)
582
682
  background: bool = False
583
683
 
684
+ # tracing context
685
+ trace_context: Optional[Dict] = None
686
+
584
687
  def normalize_batch_and_arguments(self):
585
688
  # at least one of text, input_ids, or image should be provided
586
689
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -632,10 +735,6 @@ class EmbeddingReqInput:
632
735
  for i in range(self.batch_size):
633
736
  self.sampling_params[i]["max_new_tokens"] = 0
634
737
 
635
- def regenerate_rid(self):
636
- self.rid = uuid.uuid4().hex
637
- return self.rid
638
-
639
738
  def contains_mm_input(self) -> bool:
640
739
  return (
641
740
  has_valid_data(self.image_data)
@@ -664,9 +763,7 @@ class EmbeddingReqInput:
664
763
 
665
764
 
666
765
  @dataclass
667
- class TokenizedEmbeddingReqInput:
668
- # The request id
669
- rid: str
766
+ class TokenizedEmbeddingReqInput(BaseReq):
670
767
  # The input text
671
768
  input_text: str
672
769
  # The input token ids
@@ -679,12 +776,12 @@ class TokenizedEmbeddingReqInput:
679
776
  sampling_params: SamplingParams
680
777
  # For data parallel rank routing
681
778
  data_parallel_rank: Optional[int] = None
682
- # For dp balance
683
- dp_balance_id: int = -1
779
+ # Priority for the request
780
+ priority: Optional[int] = None
684
781
 
685
782
 
686
783
  @dataclass
687
- class BatchTokenizedEmbeddingReqInput:
784
+ class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
688
785
  # The batch of tokenized embedding requests
689
786
  batch: List[TokenizedEmbeddingReqInput]
690
787
 
@@ -699,9 +796,7 @@ class BatchTokenizedEmbeddingReqInput:
699
796
 
700
797
 
701
798
  @dataclass
702
- class BatchTokenIDOut:
703
- # The request id
704
- rids: List[str]
799
+ class BatchTokenIDOutput(BaseBatchReq):
705
800
  # The finish reason
706
801
  finished_reasons: List[BaseFinishReason]
707
802
  # For incremental decoding
@@ -738,11 +833,26 @@ class BatchTokenIDOut:
738
833
  # Hidden states
739
834
  output_hidden_states: List[List[float]]
740
835
 
836
+ # The information of placeholder tokens (e.g., image token)
837
+ # idx is the index of the token in the prompt after expansion.
838
+ # val is the length of padded tokens after expansion.
839
+ placeholder_tokens_idx: List[Optional[List[int]]]
840
+ placeholder_tokens_val: List[Optional[List[int]]]
841
+
741
842
 
742
843
  @dataclass
743
- class BatchMultimodalDecodeReq:
744
- # The request id
745
- rids: List[str]
844
+ class BatchMultimodalDecodeReq(BaseBatchReq):
845
+ decoded_ids: List[int]
846
+ input_token_logprobs_val: List[float]
847
+ input_token_logprobs_idx: List[int]
848
+ output_token_logprobs_val: List[float]
849
+ output_token_logprobs_idx: List[int]
850
+ read_offsets: List[int]
851
+ skip_special_tokens: List[bool]
852
+ spaces_between_special_tokens: List[bool]
853
+ image_resolutions: List[List[int]]
854
+ resize_image_resolutions: List[List[int]]
855
+
746
856
  finished_reasons: List[BaseFinishReason]
747
857
 
748
858
  # Token counts
@@ -750,11 +860,15 @@ class BatchMultimodalDecodeReq:
750
860
  completion_tokens: List[int]
751
861
  cached_tokens: List[int]
752
862
 
863
+ # Placeholder token info
864
+ placeholder_tokens_idx: List[Optional[List[int]]]
865
+ placeholder_tokens_val: List[Optional[List[int]]]
866
+
867
+ return_bytes: bool = False
868
+
753
869
 
754
870
  @dataclass
755
- class BatchStrOut:
756
- # The request id
757
- rids: List[str]
871
+ class BatchStrOutput(BaseBatchReq):
758
872
  # The finish reason
759
873
  finished_reasons: List[dict]
760
874
  # The output decoded strings
@@ -785,26 +899,37 @@ class BatchStrOut:
785
899
  # Hidden states
786
900
  output_hidden_states: List[List[float]]
787
901
 
902
+ placeholder_tokens_idx: List[Optional[List[int]]]
903
+ placeholder_tokens_val: List[Optional[List[int]]]
904
+
788
905
 
789
906
  @dataclass
790
- class BatchMultimodalOut:
791
- # The request id
792
- rids: List[str]
907
+ class BatchMultimodalOutput(BaseBatchReq):
793
908
  # The finish reason
794
909
  finished_reasons: List[dict]
910
+ decoded_ids: List[List[int]]
795
911
  # The outputs
796
- outputs: List[List[Dict]]
912
+ outputs: Union[List[str | bytes], List[List[Dict]]]
913
+
914
+ # probability values for input tokens and output tokens
915
+ input_token_logprobs_val: List[List[float]]
916
+ input_token_logprobs_idx: List[List[int]]
917
+ output_token_logprobs_val: List[List[float]]
918
+ output_token_logprobs_idx: List[List[int]]
797
919
 
798
920
  # Token counts
799
921
  prompt_tokens: List[int]
800
922
  completion_tokens: List[int]
801
923
  cached_tokens: List[int]
802
924
 
925
+ placeholder_tokens_idx: List[Optional[List[int]]]
926
+ placeholder_tokens_val: List[Optional[List[int]]]
927
+
928
+ return_bytes: List[bool]
929
+
803
930
 
804
931
  @dataclass
805
- class BatchEmbeddingOut:
806
- # The request id
807
- rids: List[str]
932
+ class BatchEmbeddingOutput(BaseBatchReq):
808
933
  # The finish reason
809
934
  finished_reasons: List[BaseFinishReason]
810
935
  # The output embedding
@@ -812,30 +937,33 @@ class BatchEmbeddingOut:
812
937
  # Token counts
813
938
  prompt_tokens: List[int]
814
939
  cached_tokens: List[int]
940
+ # Placeholder token info
941
+ placeholder_tokens_idx: List[Optional[List[int]]]
942
+ placeholder_tokens_val: List[Optional[List[int]]]
815
943
 
816
944
 
817
945
  @dataclass
818
- class ClearHiCacheReqInput:
946
+ class ClearHiCacheReqInput(BaseReq):
819
947
  pass
820
948
 
821
949
 
822
950
  @dataclass
823
- class ClearHiCacheReqOutput:
951
+ class ClearHiCacheReqOutput(BaseReq):
824
952
  success: bool
825
953
 
826
954
 
827
955
  @dataclass
828
- class FlushCacheReqInput:
956
+ class FlushCacheReqInput(BaseReq):
829
957
  pass
830
958
 
831
959
 
832
960
  @dataclass
833
- class FlushCacheReqOutput:
961
+ class FlushCacheReqOutput(BaseReq):
834
962
  success: bool
835
963
 
836
964
 
837
965
  @dataclass
838
- class UpdateWeightFromDiskReqInput:
966
+ class UpdateWeightFromDiskReqInput(BaseReq):
839
967
  # The model path with the new weights
840
968
  model_path: str
841
969
  # The format to load the weights
@@ -844,10 +972,16 @@ class UpdateWeightFromDiskReqInput:
844
972
  abort_all_requests: bool = False
845
973
  # Optional: Update weight version along with weights
846
974
  weight_version: Optional[str] = None
975
+ # Whether to update weights asynchronously
976
+ is_async: bool = False
977
+ # Whether to empty torch cache
978
+ torch_empty_cache: bool = False
979
+ # Whether to keep the scheduler paused after weight update
980
+ keep_pause: bool = False
847
981
 
848
982
 
849
983
  @dataclass
850
- class UpdateWeightFromDiskReqOutput:
984
+ class UpdateWeightFromDiskReqOutput(BaseReq):
851
985
  success: bool
852
986
  message: str
853
987
  # Number of paused requests during weight sync.
@@ -855,7 +989,7 @@ class UpdateWeightFromDiskReqOutput:
855
989
 
856
990
 
857
991
  @dataclass
858
- class UpdateWeightsFromDistributedReqInput:
992
+ class UpdateWeightsFromDistributedReqInput(BaseReq):
859
993
  names: List[str]
860
994
  dtypes: List[str]
861
995
  shapes: List[List[int]]
@@ -870,13 +1004,13 @@ class UpdateWeightsFromDistributedReqInput:
870
1004
 
871
1005
 
872
1006
  @dataclass
873
- class UpdateWeightsFromDistributedReqOutput:
1007
+ class UpdateWeightsFromDistributedReqOutput(BaseReq):
874
1008
  success: bool
875
1009
  message: str
876
1010
 
877
1011
 
878
1012
  @dataclass
879
- class UpdateWeightsFromTensorReqInput:
1013
+ class UpdateWeightsFromTensorReqInput(BaseReq):
880
1014
  """Update model weights from tensor input.
881
1015
 
882
1016
  - Tensors are serialized for transmission
@@ -895,13 +1029,51 @@ class UpdateWeightsFromTensorReqInput:
895
1029
 
896
1030
 
897
1031
  @dataclass
898
- class UpdateWeightsFromTensorReqOutput:
1032
+ class UpdateWeightsFromTensorReqOutput(BaseReq):
1033
+ success: bool
1034
+ message: str
1035
+
1036
+
1037
+ @dataclass
1038
+ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
1039
+ # The master address
1040
+ master_address: str
1041
+ # The ports for each rank's communication group
1042
+ ports: str
1043
+ # The rank in the communication group
1044
+ group_rank: int
1045
+ # The world size
1046
+ world_size: int
1047
+ # The group name
1048
+ group_name: str = "weight_send_group"
1049
+ # The backend
1050
+ backend: str = "nccl"
1051
+
1052
+
1053
+ @dataclass
1054
+ class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
1055
+ success: bool
1056
+ message: str
1057
+
1058
+
1059
+ @dataclass
1060
+ class SendWeightsToRemoteInstanceReqInput(BaseReq):
1061
+ # The master address
1062
+ master_address: str
1063
+ # The ports for each rank's communication group
1064
+ ports: str
1065
+ # The group name
1066
+ group_name: str = "weight_send_group"
1067
+
1068
+
1069
+ @dataclass
1070
+ class SendWeightsToRemoteInstanceReqOutput(BaseReq):
899
1071
  success: bool
900
1072
  message: str
901
1073
 
902
1074
 
903
1075
  @dataclass
904
- class InitWeightsUpdateGroupReqInput:
1076
+ class InitWeightsUpdateGroupReqInput(BaseReq):
905
1077
  # The master address
906
1078
  master_address: str
907
1079
  # The master port
@@ -917,13 +1089,24 @@ class InitWeightsUpdateGroupReqInput:
917
1089
 
918
1090
 
919
1091
  @dataclass
920
- class InitWeightsUpdateGroupReqOutput:
1092
+ class InitWeightsUpdateGroupReqOutput(BaseReq):
1093
+ success: bool
1094
+ message: str
1095
+
1096
+
1097
+ @dataclass
1098
+ class DestroyWeightsUpdateGroupReqInput(BaseReq):
1099
+ group_name: str = "weight_update_group"
1100
+
1101
+
1102
+ @dataclass
1103
+ class DestroyWeightsUpdateGroupReqOutput(BaseReq):
921
1104
  success: bool
922
1105
  message: str
923
1106
 
924
1107
 
925
1108
  @dataclass
926
- class UpdateWeightVersionReqInput:
1109
+ class UpdateWeightVersionReqInput(BaseReq):
927
1110
  # The new weight version
928
1111
  new_version: str
929
1112
  # Whether to abort all running requests before updating
@@ -931,88 +1114,87 @@ class UpdateWeightVersionReqInput:
931
1114
 
932
1115
 
933
1116
  @dataclass
934
- class GetWeightsByNameReqInput:
1117
+ class GetWeightsByNameReqInput(BaseReq):
935
1118
  name: str
936
1119
  truncate_size: int = 100
937
1120
 
938
1121
 
939
1122
  @dataclass
940
- class GetWeightsByNameReqOutput:
1123
+ class GetWeightsByNameReqOutput(BaseReq):
941
1124
  parameter: list
942
1125
 
943
1126
 
944
1127
  @dataclass
945
- class ReleaseMemoryOccupationReqInput:
1128
+ class ReleaseMemoryOccupationReqInput(BaseReq):
946
1129
  # Optional tags to identify the memory region, which is primarily used for RL
947
1130
  # Currently we only support `weights` and `kv_cache`
948
1131
  tags: Optional[List[str]] = None
949
1132
 
950
1133
 
951
1134
  @dataclass
952
- class ReleaseMemoryOccupationReqOutput:
1135
+ class ReleaseMemoryOccupationReqOutput(BaseReq):
953
1136
  pass
954
1137
 
955
1138
 
956
1139
  @dataclass
957
- class ResumeMemoryOccupationReqInput:
1140
+ class ResumeMemoryOccupationReqInput(BaseReq):
958
1141
  # Optional tags to identify the memory region, which is primarily used for RL
959
1142
  # Currently we only support `weights` and `kv_cache`
960
1143
  tags: Optional[List[str]] = None
961
1144
 
962
1145
 
963
1146
  @dataclass
964
- class ResumeMemoryOccupationReqOutput:
1147
+ class ResumeMemoryOccupationReqOutput(BaseReq):
965
1148
  pass
966
1149
 
967
1150
 
968
1151
  @dataclass
969
- class SlowDownReqInput:
1152
+ class SlowDownReqInput(BaseReq):
970
1153
  forward_sleep_time: Optional[float]
971
1154
 
972
1155
 
973
1156
  @dataclass
974
- class SlowDownReqOutput:
1157
+ class SlowDownReqOutput(BaseReq):
975
1158
  pass
976
1159
 
977
1160
 
978
1161
  @dataclass
979
- class AbortReq:
980
- # The request id
981
- rid: str = ""
1162
+ class AbortReq(BaseReq):
982
1163
  # Whether to abort all requests
983
1164
  abort_all: bool = False
984
1165
  # The finished reason data
985
1166
  finished_reason: Optional[Dict[str, Any]] = None
986
- # used in MultiTokenzierManager mode
987
- rids: Optional[Union[List[str], str]] = None
1167
+ abort_reason: Optional[str] = None
988
1168
 
989
1169
  def __post_init__(self):
990
- self.rids = self.rid
1170
+ # FIXME: This is a hack to keep the same with the old code
1171
+ if self.rid is None:
1172
+ self.rid = ""
991
1173
 
992
1174
 
993
1175
  @dataclass
994
- class GetInternalStateReq:
1176
+ class GetInternalStateReq(BaseReq):
995
1177
  pass
996
1178
 
997
1179
 
998
1180
  @dataclass
999
- class GetInternalStateReqOutput:
1181
+ class GetInternalStateReqOutput(BaseReq):
1000
1182
  internal_state: Dict[Any, Any]
1001
1183
 
1002
1184
 
1003
1185
  @dataclass
1004
- class SetInternalStateReq:
1186
+ class SetInternalStateReq(BaseReq):
1005
1187
  server_args: Dict[str, Any]
1006
1188
 
1007
1189
 
1008
1190
  @dataclass
1009
- class SetInternalStateReqOutput:
1191
+ class SetInternalStateReqOutput(BaseReq):
1010
1192
  updated: bool
1011
1193
  server_args: Dict[str, Any]
1012
1194
 
1013
1195
 
1014
1196
  @dataclass
1015
- class ProfileReqInput:
1197
+ class ProfileReqInput(BaseReq):
1016
1198
  # The output directory
1017
1199
  output_dir: Optional[str] = None
1018
1200
  # If set, it profile as many as this number of steps.
@@ -1032,7 +1214,7 @@ class ProfileReqType(Enum):
1032
1214
 
1033
1215
 
1034
1216
  @dataclass
1035
- class ProfileReq:
1217
+ class ProfileReq(BaseReq):
1036
1218
  type: ProfileReqType
1037
1219
  output_dir: Optional[str] = None
1038
1220
  start_step: Optional[int] = None
@@ -1045,54 +1227,59 @@ class ProfileReq:
1045
1227
 
1046
1228
 
1047
1229
  @dataclass
1048
- class ProfileReqOutput:
1230
+ class ProfileReqOutput(BaseReq):
1049
1231
  success: bool
1050
1232
  message: str
1051
1233
 
1052
1234
 
1053
1235
  @dataclass
1054
- class FreezeGCReq:
1236
+ class FreezeGCReq(BaseReq):
1055
1237
  pass
1056
1238
 
1057
1239
 
1058
1240
  @dataclass
1059
- class ConfigureLoggingReq:
1241
+ class ConfigureLoggingReq(BaseReq):
1060
1242
  log_requests: Optional[bool] = None
1061
1243
  log_requests_level: Optional[int] = None
1062
1244
  dump_requests_folder: Optional[str] = None
1063
1245
  dump_requests_threshold: Optional[int] = None
1246
+ crash_dump_folder: Optional[str] = None
1064
1247
 
1065
1248
 
1066
1249
  @dataclass
1067
- class OpenSessionReqInput:
1250
+ class OpenSessionReqInput(BaseReq):
1068
1251
  capacity_of_str_len: int
1069
1252
  session_id: Optional[str] = None
1070
1253
 
1071
1254
 
1072
1255
  @dataclass
1073
- class CloseSessionReqInput:
1256
+ class CloseSessionReqInput(BaseReq):
1074
1257
  session_id: str
1075
1258
 
1076
1259
 
1077
1260
  @dataclass
1078
- class OpenSessionReqOutput:
1261
+ class OpenSessionReqOutput(BaseReq):
1079
1262
  session_id: Optional[str]
1080
1263
  success: bool
1081
1264
 
1082
1265
 
1083
1266
  @dataclass
1084
- class HealthCheckOutput:
1267
+ class HealthCheckOutput(BaseReq):
1085
1268
  pass
1086
1269
 
1087
1270
 
1088
- class ExpertDistributionReq(Enum):
1271
+ class ExpertDistributionReqType(Enum):
1089
1272
  START_RECORD = 1
1090
1273
  STOP_RECORD = 2
1091
1274
  DUMP_RECORD = 3
1092
1275
 
1093
1276
 
1277
+ class ExpertDistributionReq(BaseReq):
1278
+ action: ExpertDistributionReqType
1279
+
1280
+
1094
1281
  @dataclass
1095
- class ExpertDistributionReqOutput:
1282
+ class ExpertDistributionReqOutput(BaseReq):
1096
1283
  pass
1097
1284
 
1098
1285
 
@@ -1110,7 +1297,7 @@ class Tool:
1110
1297
 
1111
1298
 
1112
1299
  @dataclass
1113
- class ParseFunctionCallReq:
1300
+ class ParseFunctionCallReq(BaseReq):
1114
1301
  text: str # The text to parse.
1115
1302
  tools: List[Tool] = field(
1116
1303
  default_factory=list
@@ -1121,31 +1308,31 @@ class ParseFunctionCallReq:
1121
1308
 
1122
1309
 
1123
1310
  @dataclass
1124
- class SeparateReasoningReqInput:
1311
+ class SeparateReasoningReqInput(BaseReq):
1125
1312
  text: str # The text to parse.
1126
1313
  reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
1127
1314
 
1128
1315
 
1129
1316
  @dataclass
1130
- class VertexGenerateReqInput:
1317
+ class VertexGenerateReqInput(BaseReq):
1131
1318
  instances: List[dict]
1132
1319
  parameters: Optional[dict] = None
1133
1320
 
1134
1321
 
1135
1322
  @dataclass
1136
- class RpcReqInput:
1323
+ class RpcReqInput(BaseReq):
1137
1324
  method: str
1138
1325
  parameters: Optional[Dict] = None
1139
1326
 
1140
1327
 
1141
1328
  @dataclass
1142
- class RpcReqOutput:
1329
+ class RpcReqOutput(BaseReq):
1143
1330
  success: bool
1144
1331
  message: str
1145
1332
 
1146
1333
 
1147
1334
  @dataclass
1148
- class LoadLoRAAdapterReqInput:
1335
+ class LoadLoRAAdapterReqInput(BaseReq):
1149
1336
  # The name of the lora module to newly loaded.
1150
1337
  lora_name: str
1151
1338
  # The path of loading.
@@ -1165,7 +1352,7 @@ class LoadLoRAAdapterReqInput:
1165
1352
 
1166
1353
 
1167
1354
  @dataclass
1168
- class UnloadLoRAAdapterReqInput:
1355
+ class UnloadLoRAAdapterReqInput(BaseReq):
1169
1356
  # The name of lora module to unload.
1170
1357
  lora_name: str
1171
1358
  # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
@@ -1179,23 +1366,23 @@ class UnloadLoRAAdapterReqInput:
1179
1366
 
1180
1367
 
1181
1368
  @dataclass
1182
- class LoRAUpdateResult:
1369
+ class LoRAUpdateOutput(BaseReq):
1183
1370
  success: bool
1184
1371
  error_message: Optional[str] = None
1185
1372
  loaded_adapters: Optional[Dict[str, LoRARef]] = None
1186
1373
 
1187
1374
 
1188
- LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
1375
+ LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
1189
1376
 
1190
1377
 
1191
1378
  @dataclass
1192
- class MultiTokenizerRegisterReq:
1193
- rids: Optional[Union[List[str], str]] = None
1379
+ class MultiTokenizerRegisterReq(BaseBatchReq):
1194
1380
  ipc_name: Optional[str] = None
1195
1381
 
1196
1382
 
1197
1383
  @dataclass
1198
- class MultiTokenizerWarpper:
1384
+ class MultiTokenizerWrapper:
1385
+ # FIXME(lsyin): remove this
1199
1386
  worker_id: int
1200
1387
  obj: Optional[Any] = None
1201
1388
 
@@ -1206,5 +1393,49 @@ class BlockReqType(Enum):
1206
1393
 
1207
1394
 
1208
1395
  @dataclass
1209
- class BlockReqInput:
1396
+ class BlockReqInput(BaseReq):
1210
1397
  type: BlockReqType
1398
+
1399
+
1400
+ @dataclass
1401
+ class GetLoadReqInput(BaseReq):
1402
+ pass
1403
+
1404
+
1405
+ @dataclass
1406
+ class GetLoadReqOutput(BaseReq):
1407
+ dp_rank: int
1408
+ num_reqs: int
1409
+ num_waiting_reqs: int
1410
+ num_tokens: int
1411
+
1412
+
1413
+ @dataclass
1414
+ class WatchLoadUpdateReq(BaseReq):
1415
+ loads: List[GetLoadReqOutput]
1416
+
1417
+
1418
+ def _check_all_req_types():
1419
+ """A helper function to check all request types are defined in this file."""
1420
+ import inspect
1421
+ import sys
1422
+
1423
+ all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
1424
+ for class_type in all_classes:
1425
+ # check its name
1426
+ name = class_type[0]
1427
+ is_io_struct = (
1428
+ name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
1429
+ )
1430
+ is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
1431
+ class_type[1], BaseBatchReq
1432
+ )
1433
+ if is_io_struct and not is_base_req:
1434
+ raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
1435
+ if is_base_req and not is_io_struct:
1436
+ raise ValueError(
1437
+ f"{name} is a subclass of BaseReq but not follow the naming convention."
1438
+ )
1439
+
1440
+
1441
+ _check_all_req_types()