sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py

@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
 
 import copy
 import uuid
+from abc import ABC
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -35,6 +36,32 @@ else:
     Image = Any
 
 
+@dataclass
+class BaseReq(ABC):
+    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
+    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
+
+    def regenerate_rid(self):
+        """Generate a new request ID and return it."""
+        if isinstance(self.rid, list):
+            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
+        else:
+            self.rid = uuid.uuid4().hex
+        return self.rid
+
+
+@dataclass
+class BaseBatchReq(ABC):
+    rids: Optional[List[str]] = field(default=None, kw_only=True)
+    http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
+
+    def regenerate_rids(self):
+        """Generate new request IDs and return them."""
+        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
+        return self.rids
+
+
+# Parameters for a session
 @dataclass
 class SessionParams:
     id: Optional[str] = None
@@ -62,7 +89,7 @@ MultimodalDataInputFormat = Union[
 
 
 @dataclass
-class GenerateReqInput:
+class GenerateReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can specify either text or input_ids
@@ -82,8 +109,6 @@ class GenerateReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
@@ -132,17 +157,23 @@ class GenerateReqInput:
     # Conversation id used for tracking requests
     conversation_id: Optional[str] = None
 
-    # Label for the request
-    label: Optional[str] = None
-
     # Priority for the request
     priority: Optional[int] = None
 
-    # Image gen grpc migration
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
+
+    # Whether to disallow logging for this request (e.g. due to ZDR)
+    no_logs: bool = False
+
+    # For custom metric labels
+    custom_labels: Optional[Dict[str, str]] = None
+
+    # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False
 
-    # For customer metric labels
-    customer_labels: Optional[Dict[str, str]] = None
+    # Whether to return entropy
+    return_entropy: bool = False
 
     def contains_mm_input(self) -> bool:
         return (
@@ -485,11 +516,6 @@ class GenerateReqInput:
             ):
                 raise ValueError("Session params must be a dict or a list of dicts.")
 
-    def regenerate_rid(self):
-        """Generate a new request ID and return it."""
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
@@ -542,16 +568,17 @@ class GenerateReqInput:
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
             conversation_id=self.conversation_id,
-            label=self.label,
             priority=self.priority,
+            extra_key=self.extra_key,
+            no_logs=self.no_logs,
+            custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
+            return_entropy=self.return_entropy,
         )
 
 
 @dataclass
-class TokenizedGenerateReqInput:
-    # The request id
-    rid: str
+class TokenizedGenerateReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -570,6 +597,7 @@ class TokenizedGenerateReqInput:
     token_ids_logprob: List[int]
     # Whether to stream output
     stream: bool
+
     # Whether to return hidden states
     return_hidden_states: bool = False
 
@@ -596,24 +624,27 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
-    # For dp balance
-    dp_balance_id: int = -1
-
-    # Label for the request
-    label: Optional[str] = None
-
     # Priority for the request
     priority: Optional[int] = None
 
-    # Image gen grpc migration
-    return_bytes: bool = False
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[str] = None
+
+    # Whether to disallow logging for this request (e.g. due to ZDR)
+    no_logs: bool = False
 
     # tracing context
     trace_context: Optional[Dict] = None
 
+    # (Internal) Whether to return bytes for image generation
+    return_bytes: bool = False
+
+    # Whether to return entropy
+    return_entropy: bool = False
+
 
 @dataclass
-class BatchTokenizedGenerateReqInput:
+class BatchTokenizedGenerateReqInput(BaseBatchReq):
     # The batch of tokenized requests
     batch: List[TokenizedGenerateReqInput]
 
@@ -628,7 +659,7 @@ class BatchTokenizedGenerateReqInput:
 
 
 @dataclass
-class EmbeddingReqInput:
+class EmbeddingReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
@@ -644,8 +675,6 @@ class EmbeddingReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
@@ -656,6 +685,8 @@ class EmbeddingReqInput:
     modalities: Optional[List[str]] = None
     # For cross-encoder requests
    is_cross_encoder_request: bool = False
+    # Priority for the request
+    priority: Optional[int] = None
 
     # For background responses (OpenAI responses API)
     background: bool = False
@@ -714,10 +745,6 @@ class EmbeddingReqInput:
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 0
 
-    def regenerate_rid(self):
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -746,9 +773,7 @@ class EmbeddingReqInput:
 
 
 @dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
+class TokenizedEmbeddingReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -761,12 +786,12 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
-    # For dp balance
-    dp_balance_id: int = -1
+    # Priority for the request
+    priority: Optional[int] = None
 
 
 @dataclass
-class BatchTokenizedEmbeddingReqInput:
+class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
     # The batch of tokenized embedding requests
     batch: List[TokenizedEmbeddingReqInput]
 
@@ -781,9 +806,7 @@ class BatchTokenizedEmbeddingReqInput:
 
 
 @dataclass
-class BatchTokenIDOut:
-    # The request id
-    rids: List[str]
+class BatchTokenIDOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
@@ -802,6 +825,7 @@ class BatchTokenIDOut:
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
+    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -816,6 +840,7 @@ class BatchTokenIDOut:
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
@@ -826,9 +851,12 @@ class BatchTokenIDOut:
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
-class BatchMultimodalDecodeReq:
+class BatchMultimodalDecodeReq(BaseBatchReq):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -840,8 +868,6 @@ class BatchMultimodalDecodeReq:
     image_resolutions: List[List[int]]
     resize_image_resolutions: List[List[int]]
 
-    # The request id
-    rids: List[str]
     finished_reasons: List[BaseFinishReason]
 
     # Token counts
@@ -849,17 +875,20 @@ class BatchMultimodalDecodeReq:
     completion_tokens: List[int]
     cached_tokens: List[int]
 
-    # Placeholder token info
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
-    return_bytes: bool = False
+    return_bytes: List[bool]
+
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
 
 
 @dataclass
-class BatchStrOut:
-    # The request id
-    rids: List[str]
+class BatchStrOutput(BaseBatchReq):
    # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings
@@ -872,6 +901,7 @@ class BatchStrOut:
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
+    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -886,18 +916,23 @@ class BatchStrOut:
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
-class BatchMultimodalOut:
-    # The request id
-    rids: List[str]
+class BatchMultimodalOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]
@@ -922,13 +957,11 @@ class BatchMultimodalOut:
 
 
 @dataclass
-class BatchEmbeddingOut:
-    # The request id
-    rids: List[str]
+class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
-    embeddings: List[List[float]]
+    embeddings: Union[List[List[float]], List[Dict[int, float]]]
     # Token counts
     prompt_tokens: List[int]
     cached_tokens: List[int]
@@ -938,27 +971,27 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class ClearHiCacheReqInput:
+class ClearHiCacheReqInput(BaseReq):
     pass
 
 
 @dataclass
-class ClearHiCacheReqOutput:
+class ClearHiCacheReqOutput(BaseReq):
     success: bool
 
 
 @dataclass
-class FlushCacheReqInput:
+class FlushCacheReqInput(BaseReq):
     pass
 
 
 @dataclass
-class FlushCacheReqOutput:
+class FlushCacheReqOutput(BaseReq):
     success: bool
 
 
 @dataclass
-class UpdateWeightFromDiskReqInput:
+class UpdateWeightFromDiskReqInput(BaseReq):
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -973,10 +1006,12 @@ class UpdateWeightFromDiskReqInput:
     torch_empty_cache: bool = False
     # Whether to keep the scheduler paused after weight update
     keep_pause: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_step: int = 0
 
 
 @dataclass
-class UpdateWeightFromDiskReqOutput:
+class UpdateWeightFromDiskReqOutput(BaseReq):
     success: bool
     message: str
     # Number of paused requests during weight sync.
@@ -984,7 +1019,7 @@ class UpdateWeightFromDiskReqOutput:
 
 
 @dataclass
-class UpdateWeightsFromDistributedReqInput:
+class UpdateWeightsFromDistributedReqInput(BaseReq):
     names: List[str]
     dtypes: List[str]
     shapes: List[List[int]]
@@ -999,13 +1034,13 @@ class UpdateWeightsFromDistributedReqInput:
 
 
 @dataclass
-class UpdateWeightsFromDistributedReqOutput:
+class UpdateWeightsFromDistributedReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class UpdateWeightsFromTensorReqInput:
+class UpdateWeightsFromTensorReqInput(BaseReq):
     """Update model weights from tensor input.
 
     - Tensors are serialized for transmission
@@ -1024,13 +1059,13 @@ class UpdateWeightsFromTensorReqInput:
 
 
 @dataclass
-class UpdateWeightsFromTensorReqOutput:
+class UpdateWeightsFromTensorReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqInput:
+class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1045,14 +1080,32 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
     backend: str = "nccl"
 
 
+# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
+# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
+@dataclass
+class UpdateWeightsFromIPCReqInput(BaseReq):
+    # ZMQ socket paths for each device UUID
+    zmq_handles: Dict[str, str]
+    # Whether to flush cache after weight update
+    flush_cache: bool = True
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromIPCReqOutput(BaseReq):
+    success: bool
+    message: str
+
+
 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqOutput:
+class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class SendWeightsToRemoteInstanceReqInput:
+class SendWeightsToRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1062,13 +1115,13 @@ class SendWeightsToRemoteInstanceReqInput:
 
 
 @dataclass
-class SendWeightsToRemoteInstanceReqOutput:
+class SendWeightsToRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class InitWeightsUpdateGroupReqInput:
+class InitWeightsUpdateGroupReqInput(BaseReq):
     # The master address
     master_address: str
     # The master port
@@ -1084,13 +1137,24 @@ class InitWeightsUpdateGroupReqInput:
 
 
 @dataclass
-class InitWeightsUpdateGroupReqOutput:
+class InitWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class UpdateWeightVersionReqInput:
+class DestroyWeightsUpdateGroupReqInput(BaseReq):
+    group_name: str = "weight_update_group"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReqOutput(BaseReq):
+    success: bool
+    message: str
+
+
+@dataclass
+class UpdateWeightVersionReqInput(BaseReq):
     # The new weight version
     new_version: str
     # Whether to abort all running requests before updating
@@ -1098,89 +1162,87 @@ class UpdateWeightVersionReqInput:
 
 
 @dataclass
-class GetWeightsByNameReqInput:
+class GetWeightsByNameReqInput(BaseReq):
     name: str
     truncate_size: int = 100
 
 
 @dataclass
-class GetWeightsByNameReqOutput:
+class GetWeightsByNameReqOutput(BaseReq):
     parameter: list
 
 
 @dataclass
-class ReleaseMemoryOccupationReqInput:
+class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
 
 
 @dataclass
-class ReleaseMemoryOccupationReqOutput:
+class ReleaseMemoryOccupationReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class ResumeMemoryOccupationReqInput:
+class ResumeMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
 
 
 @dataclass
-class ResumeMemoryOccupationReqOutput:
+class ResumeMemoryOccupationReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class SlowDownReqInput:
+class SlowDownReqInput(BaseReq):
     forward_sleep_time: Optional[float]
 
 
 @dataclass
-class SlowDownReqOutput:
+class SlowDownReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class AbortReq:
-    # The request id
-    rid: str = ""
+class AbortReq(BaseReq):
     # Whether to abort all requests
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
     abort_reason: Optional[str] = None
-    # used in MultiTokenzierManager mode
-    rids: Optional[Union[List[str], str]] = None
 
     def __post_init__(self):
-        self.rids = self.rid
+        # FIXME: This is a hack to keep the same with the old code
+        if self.rid is None:
+            self.rid = ""
 
 
 @dataclass
-class GetInternalStateReq:
+class GetInternalStateReq(BaseReq):
     pass
 
 
 @dataclass
-class GetInternalStateReqOutput:
+class GetInternalStateReqOutput(BaseReq):
     internal_state: Dict[Any, Any]
 
 
 @dataclass
-class SetInternalStateReq:
+class SetInternalStateReq(BaseReq):
     server_args: Dict[str, Any]
 
 
 @dataclass
-class SetInternalStateReqOutput:
+class SetInternalStateReqOutput(BaseReq):
     updated: bool
     server_args: Dict[str, Any]
 
 
 @dataclass
-class ProfileReqInput:
+class ProfileReqInput(BaseReq):
     # The output directory
     output_dir: Optional[str] = None
     # If set, it profile as many as this number of steps.
@@ -1192,6 +1254,8 @@ class ProfileReqInput:
     profile_by_stage: bool = False
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
+    # Merge profiles from all ranks into a single trace
+    merge_profiles: bool = False
 
 
 class ProfileReqType(Enum):
@@ -1200,7 +1264,7 @@ class ProfileReqType(Enum):
 
 
 @dataclass
-class ProfileReq:
+class ProfileReq(BaseReq):
     type: ProfileReqType
     output_dir: Optional[str] = None
     start_step: Optional[int] = None
@@ -1210,21 +1274,23 @@ class ProfileReq:
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
     profile_id: Optional[str] = None
+    # Merge profiles from all ranks into a single trace
+    merge_profiles: bool = False
 
 
 @dataclass
-class ProfileReqOutput:
+class ProfileReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class FreezeGCReq:
+class FreezeGCReq(BaseReq):
     pass
 
 
 @dataclass
-class ConfigureLoggingReq:
+class ConfigureLoggingReq(BaseReq):
     log_requests: Optional[bool] = None
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
@@ -1233,35 +1299,39 @@ class ConfigureLoggingReq:
 
 
 @dataclass
-class OpenSessionReqInput:
+class OpenSessionReqInput(BaseReq):
     capacity_of_str_len: int
     session_id: Optional[str] = None
 
 
 @dataclass
-class CloseSessionReqInput:
+class CloseSessionReqInput(BaseReq):
     session_id: str
 
 
 @dataclass
-class OpenSessionReqOutput:
+class OpenSessionReqOutput(BaseReq):
     session_id: Optional[str]
     success: bool
 
 
 @dataclass
-class HealthCheckOutput:
+class HealthCheckOutput(BaseReq):
     pass
 
 
-class ExpertDistributionReq(Enum):
+class ExpertDistributionReqType(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3
 
 
+class ExpertDistributionReq(BaseReq):
+    action: ExpertDistributionReqType
+
+
 @dataclass
-class ExpertDistributionReqOutput:
+class ExpertDistributionReqOutput(BaseReq):
     pass
 
 
@@ -1279,7 +1349,7 @@ class Tool:
 
 
 @dataclass
-class ParseFunctionCallReq:
+class ParseFunctionCallReq(BaseReq):
     text: str  # The text to parse.
     tools: List[Tool] = field(
         default_factory=list
@@ -1290,31 +1360,31 @@ class ParseFunctionCallReq:
 
 
 @dataclass
-class SeparateReasoningReqInput:
+class SeparateReasoningReqInput(BaseReq):
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
 
 
 @dataclass
-class VertexGenerateReqInput:
+class VertexGenerateReqInput(BaseReq):
     instances: List[dict]
     parameters: Optional[dict] = None
 
 
 @dataclass
-class RpcReqInput:
+class RpcReqInput(BaseReq):
     method: str
     parameters: Optional[Dict] = None
 
 
 @dataclass
-class RpcReqOutput:
+class RpcReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class LoadLoRAAdapterReqInput:
+class LoadLoRAAdapterReqInput(BaseReq):
     # The name of the lora module to newly loaded.
     lora_name: str
     # The path of loading.
@@ -1334,7 +1404,7 @@ class LoadLoRAAdapterReqInput:
 
 
 @dataclass
-class UnloadLoRAAdapterReqInput:
+class UnloadLoRAAdapterReqInput(BaseReq):
     # The name of lora module to unload.
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
@@ -1348,25 +1418,13 @@ class UnloadLoRAAdapterReqInput:
 
 
 @dataclass
-class LoRAUpdateResult:
+class LoRAUpdateOutput(BaseReq):
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
-LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
-
-
-@dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
-    ipc_name: Optional[str] = None
-
-
-@dataclass
-class MultiTokenizerWrapper:
-    worker_id: int
-    obj: Optional[Any] = None
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
 
 
 class BlockReqType(Enum):
@@ -1375,17 +1433,17 @@ class BlockReqType(Enum):
 
 
 @dataclass
-class BlockReqInput:
+class BlockReqInput(BaseReq):
     type: BlockReqType
 
 
 @dataclass
-class GetLoadReqInput:
+class GetLoadReqInput(BaseReq):
     pass
 
 
 @dataclass
-class GetLoadReqOutput:
+class GetLoadReqOutput(BaseReq):
     dp_rank: int
     num_reqs: int
     num_waiting_reqs: int
@@ -1393,5 +1451,41 @@ class GetLoadReqOutput:
 
 
 @dataclass
-class WatchLoadUpdateReq:
+class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
+
+
+@dataclass
+class LazyDumpTensorsReqInput(BaseReq):
+    pass
+
+
+@dataclass
+class LazyDumpTensorsReqOutput(BaseReq):
+    success: bool
+
+
+def _check_all_req_types():
+    """A helper function to check all request types are defined in this file."""
+    import inspect
+    import sys
+
+    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+    for class_type in all_classes:
+        # check its name
+        name = class_type[0]
+        is_io_struct = (
+            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
+        )
+        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
+            class_type[1], BaseBatchReq
+        )
+        if is_io_struct and not is_base_req:
+            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
+        if is_base_req and not is_io_struct:
+            raise ValueError(
+                f"{name} is a subclass of BaseReq but not follow the naming convention."
+            )
+
+
+_check_all_req_types()
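
Note on the hunk above: the central change in io_struct.py is that every request/response struct now inherits from BaseReq or BaseBatchReq, which carry the request-id bookkeeping (rid / http_worker_ipc and regenerate_rid, or their batch counterparts) as keyword-only fields, so the per-class rid fields and regenerate_rid helpers could be removed. The following is a minimal standalone sketch of that pattern; ToyGenerateReqInput is a hypothetical class used only for illustration (it does not import sglang), the BaseReq body is copied from the diff, and dataclasses.field(kw_only=True) requires Python 3.10+.

import uuid
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class BaseReq(ABC):
    # Keyword-only bookkeeping fields shared by all request types (as in the diff above).
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID (or one per request when rid is a list) and return it."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid


@dataclass
class ToyGenerateReqInput(BaseReq):
    # Hypothetical request type; real ones (GenerateReqInput, AbortReq, ...) live in
    # sglang/srt/managers/io_struct.py and only add their payload fields.
    text: str = ""


req = ToyGenerateReqInput(text="hello")  # rid defaults to None and is keyword-only
print(req.regenerate_rid())              # prints a fresh 32-character hex request id

Because the id fields are keyword-only on the base class, subclasses can keep declaring positional payload fields without running into dataclass default-ordering errors, and the _check_all_req_types() hook at the end of the file enforces that anything named *Req/*Input/*Output actually inherits from one of the two bases.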