sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
-    global_server_args_dict,
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
             input_ids_tensor[input_ids_tensor == token_id] = pad_value
 
         ret_input_ids = input_ids_tensor.tolist()
-
         return ret_input_ids
 
 
@@ -428,7 +427,7 @@ def _adjust_embedding_length(
             f"tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            chunked_prefill_size = get_global_server_args().chunked_prefill_size
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
@@ -507,6 +506,7 @@ def embed_mm_inputs(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
+    use_deepstack: Dict[Modality, bool] = {},
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.
@@ -522,7 +522,7 @@ def embed_mm_inputs(
     Returns:
         Combined embedding tensor with multimodal content integrated
     """
-
+    other_info = {}
     if mm_inputs_list is None:
         return None
 
@@ -532,7 +532,9 @@ def embed_mm_inputs(
     for mm_inputs in mm_inputs_list:
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
 
-    embeddings, masks = [], []
+    # deepstack_embeddings: per-modality
+    modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
+
     # 2. Get multimodal embedding separately
     # Try get mm embedding if any
     for modality in Modality.all():
@@ -548,7 +550,8 @@ def embed_mm_inputs(
         # "image", "video", etc
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
-        if len(items) != 0 and embedder is not None:
+        if len(items) != 0:
+            assert embedder is not None, f"no embedding method found for {modality}"
             placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
@@ -578,6 +581,13 @@ def embed_mm_inputs(
                 extend_length=extend_seq_lens,
                 items_offset_list=items_offsets,
             )
+
+            if use_deepstack.get(modality, None) and embedding is not None:
+                embedding, deepstack_embedding = (
+                    multimodal_model.separate_deepstack_embeds(embedding)
+                )
+                deepstack_embeddings += [deepstack_embedding]
+            modalities += [modality]
             embeddings += [embedding]
             masks += [mask]
 
@@ -590,14 +600,37 @@ def embed_mm_inputs(
     input_ids.clamp_(min=0, max=vocab_size - 1)
     inputs_embeds = input_embedding(input_ids)
 
+    # deepstack embedding
+    if use_deepstack:
+        num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
+
+        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
+            inputs_embeds.shape[-1] * num_deepstack_embeddings,
+        )
+        # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
+        input_deepstack_embeds = torch.zeros(
+            deepstack_embedding_shape,
+            device=inputs_embeds.device,
+            dtype=inputs_embeds.dtype,
+        )
+
+        other_info["input_deepstack_embeds"] = input_deepstack_embeds
+
     # 4. scatter embeddings into input embedding
-    for embedding, mask in zip(embeddings, masks):
+    for i, modality, embedding, mask in zip(
+        range(len(embeddings)), modalities, embeddings, masks
+    ):
         if embedding is None or mask is None:
             continue
         # in-place update
         indices = torch.where(mask.squeeze(dim=-1))[0]
         inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-    return inputs_embeds
+        if use_deepstack.get(modality, None):
+            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+
+    return inputs_embeds, other_info
 
 
 def general_mm_embed_routine(
@@ -609,6 +642,7 @@ def general_mm_embed_routine(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
+    use_deepstack: Dict[Modality, bool] = {},
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -620,6 +654,7 @@ def general_mm_embed_routine(
         language_model: Base language model to use
         data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
+        use_deepstack: Whether to use deepstack embeddings for each modality, default False
         **kwargs: Additional arguments passed to language model
 
     Returns:
@@ -645,16 +680,20 @@ def general_mm_embed_routine(
             for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
             if forward_batch.mm_inputs[i] is not None
         ]
-        inputs_embeds = embed_mm_inputs(
+        inputs_embeds, other_info = embed_mm_inputs(
             mm_inputs_list=mm_inputs_list,
             extend_prefix_lens=extend_prefix_lens,
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
-            input_embedding=embed_tokens,
             multimodal_model=multimodal_model,
+            input_embedding=embed_tokens,
             data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
+            use_deepstack=use_deepstack,
         )
+        # add for qwen3_vl deepstack
+        if use_deepstack:
+            kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
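
The deepstack change above allocates a zero-filled `input_deepstack_embeds` buffer whose last dimension is `hidden_size * len(deepstack_visual_indexes)` and scatters each modality's deepstack slice at the same placeholder positions as the regular multimodal embedding, returning it through `other_info`. A minimal standalone sketch of that scatter pattern (the shapes and the split of the vision output are illustrative assumptions, not the exact sglang code):

import torch

# Illustrative sizes: 6 tokens, hidden size 4, 2 deepstack layers.
num_tokens, hidden, num_deepstack = 6, 4, 2

inputs_embeds = torch.randn(num_tokens, hidden)
# Zero-filled buffer with the same length but hidden * num_deepstack channels.
input_deepstack_embeds = torch.zeros(num_tokens, hidden * num_deepstack)

# Suppose tokens 2..3 are multimodal placeholders.
mask = torch.tensor([False, False, True, True, False, False])
indices = torch.where(mask)[0]

# Assume the vision tower returns (num_mm_tokens, hidden * (1 + num_deepstack)):
# the first `hidden` channels are the regular embedding, the rest are deepstack.
vision_out = torch.randn(indices.numel(), hidden * (1 + num_deepstack))
embedding, deepstack_embedding = vision_out[:, :hidden], vision_out[:, hidden:]

# In-place scatter at the placeholder positions, mirroring the diff.
inputs_embeds[indices] = embedding
input_deepstack_embeds[indices] = deepstack_embedding

print(inputs_embeds.shape, input_deepstack_embeds.shape)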
sglang/srt/managers/multi_tokenizer_mixin.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+"""Mixin class and utils for multi-http-worker mode"""
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -21,7 +23,7 @@ import sys
 import threading
 from functools import partialmethod
 from multiprocessing import shared_memory
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Union
 
 import setproctitle
 import zmq
@@ -30,12 +32,12 @@ import zmq.asyncio
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
-    MultiTokenizerRegisterReq,
-    MultiTokenizerWrapper,
+    BaseBatchReq,
+    BaseReq,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
 )
 from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.detokenizer_manager import DetokenizerManager
+
 logger = logging.getLogger(__name__)
 
 
@@ -56,35 +61,30 @@ class SocketMapping:
             socket.close()
         self._mapping.clear()
 
-    def register_ipc_mapping(
-        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
-    ):
+    def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
         type_str = "tokenizer" if is_tokenizer else "detokenizer"
-        if worker_id in self._mapping:
-            logger.warning(
-                f"{type_str} already registered with worker {worker_id}, skipping..."
-            )
+        if ipc_name in self._mapping:
+            logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
             return
-        logger.info(
-            f"{type_str} not registered with worker {worker_id}, registering..."
-        )
-        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
-        self._mapping[worker_id] = socket
-        self._mapping[worker_id].send_pyobj(recv_obj)
-
-    def send_output(self, worker_id: str, output: Any):
-        if worker_id not in self._mapping:
-            logger.error(
-                f"worker ID {worker_id} not registered. Check if the server Process is alive"
-            )
+        logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
+        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
+        self._mapping[ipc_name] = socket
+
+    def send_output(self, ipc_name: str, output: Any):
+        if ipc_name is None:
+            # Some unhandled cases
+            logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
            return
-        self._mapping[worker_id].send_pyobj(output)
+
+        if ipc_name not in self._mapping:
+            self._register_ipc_mapping(ipc_name, is_tokenizer=False)
+        self._mapping[ipc_name].send_pyobj(output)
 
 
 def _handle_output_by_index(output, i):
     """NOTE: A maintainable method is better here."""
-    if isinstance(output, BatchTokenIDOut):
-        new_output = BatchTokenIDOut(
+    if isinstance(output, BatchTokenIDOutput):
+        new_output = BatchTokenIDOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -197,9 +202,10 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
-    elif isinstance(output, BatchEmbeddingOut):
-        new_output = BatchEmbeddingOut(
+    elif isinstance(output, BatchEmbeddingOutput):
+        new_output = BatchEmbeddingOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -216,8 +222,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchStrOut):
-        new_output = BatchStrOut(
+    elif isinstance(output, BatchStrOutput):
+        new_output = BatchStrOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -246,6 +252,11 @@ def _handle_output_by_index(output, i):
             spec_verify_ct=(
                 [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
+            spec_accepted_tokens=(
+                [output.spec_accepted_tokens[i]]
+                if len(output.spec_accepted_tokens) > i
+                else None
+            ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
                 if output.input_token_logprobs_val
@@ -306,6 +317,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -313,9 +329,10 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
-    elif isinstance(output, BatchMultimodalOut):
-        new_output = BatchMultimodalOut(
+    elif isinstance(output, BatchMultimodalOutput):
+        new_output = BatchMultimodalOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -343,22 +360,13 @@ def _handle_output_by_index(output, i):
 
 
 class MultiHttpWorkerDetokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+    """Mixin class for DetokenizerManager"""
 
-    def get_worker_ids_from_req_rids(self, rids):
-        if isinstance(rids, list):
-            worker_ids = [int(rid.split("_")[0]) for rid in rids]
-        elif isinstance(rids, str):
-            worker_ids = [int(rids.split("_")[0])]
-        else:
-            worker_ids = []
-        return worker_ids
-
-    def maybe_clear_socket_mapping(self):
+    def maybe_clear_socket_mapping(self: DetokenizerManager):
         if hasattr(self, "socket_mapping"):
             self.socket_mapping.clear_all_sockets()
 
-    def multi_http_worker_event_loop(self):
+    def multi_http_worker_event_loop(self: DetokenizerManager):
         """The event loop that handles requests, for multi multi-http-worker mode"""
         self.socket_mapping = SocketMapping()
         while True:
@@ -366,27 +374,19 @@ class MultiHttpWorkerDetokenizerMixin:
             output = self._request_dispatcher(recv_obj)
             if output is None:
                 continue
-            # Extract worker_id from rid
-            if isinstance(recv_obj.rids, list):
-                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
+
+            assert isinstance(
+                recv_obj, BaseBatchReq
+            ), "for multi-http-worker, recv_obj must be BaseBatchReq"
 
             # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    self.socket_mapping.register_ipc_mapping(
-                        recv_obj, worker_id, is_tokenizer=False
-                    )
-                else:
-                    new_output = _handle_output_by_index(output, i)
-                    self.socket_mapping.send_output(worker_id, new_output)
+            for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
+                new_output = _handle_output_by_index(output, i)
+                self.socket_mapping.send_output(ipc_name, new_output)
 
 
 class MultiTokenizerRouter:
-    """A router to receive requests from MultiTokenizerManager"""
+    """A router to receive requests from TokenizerWorker"""
 
     def __init__(
         self,
@@ -432,30 +432,21 @@ class MultiTokenizerRouter:
         await self._distribute_result_to_workers(recv_obj)
 
     async def _distribute_result_to_workers(self, recv_obj):
-        """Distribute result to corresponding workers based on rid"""
-        if isinstance(recv_obj, MultiTokenizerWrapper):
-            worker_ids = [recv_obj.worker_id]
-            recv_obj = recv_obj.obj
+        # Distribute result to each worker
+        if isinstance(recv_obj, BaseReq):
+            ipc_names = [recv_obj.http_worker_ipc]
+        elif isinstance(recv_obj, BaseBatchReq):
+            ipc_names = recv_obj.http_worker_ipcs
         else:
-            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
-
-        if len(worker_ids) == 0:
-            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
-            return
+            raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
 
-        # Distribute result to each worker
-        for i, worker_id in enumerate(worker_ids):
-            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                self.socket_mapping.register_ipc_mapping(
-                    recv_obj, worker_id, is_tokenizer=True
-                )
-            else:
-                new_recv_obj = _handle_output_by_index(recv_obj, i)
-                self.socket_mapping.send_output(worker_id, new_recv_obj)
+        for i, ipc_name in enumerate(ipc_names):
+            new_recv_obj = _handle_output_by_index(recv_obj, i)
+            self.socket_mapping.send_output(ipc_name, new_recv_obj)
 
 
-class MultiTokenizerManager(TokenizerManager):
-    """Multi Process Tokenizer Manager that tokenizes the text."""
+class TokenizerWorker(TokenizerManager):
+    """Tokenizer Worker in multi-http-worker mode"""
 
     def __init__(
         self,
@@ -483,21 +474,15 @@ class MultiTokenizerManager(TokenizerManager):
         self.register_multi_tokenizer_communicator = _Communicator(
             self.send_to_scheduler, 2
         )
-        self._result_dispatcher._mapping.append(
-            (
-                MultiTokenizerRegisterReq,
-                self.register_multi_tokenizer_communicator.handle_recv,
-            )
-        )
 
-    async def register_to_main_tokenizer_manager(self):
-        """Register this worker to the main TokenizerManager"""
-        # create a handle loop to receive messages from the main TokenizerManager
-        self.auto_create_handle_loop()
-        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
-        req.ipc_name = self.tokenizer_ipc_name
-        _Communicator.enable_multi_tokenizer = True
-        await self.register_multi_tokenizer_communicator(req)
+    def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
+
+        if isinstance(req, BaseReq):
+            req.http_worker_ipc = self.tokenizer_ipc_name
+        elif isinstance(req, BaseBatchReq):
+            req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
+        else:
+            raise ValueError(f"Unknown req type: {type(req)}")
 
 
 async def print_exception_wrapper(func):
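
The rework above drops the rid-encoded worker IDs and the explicit registration round trip: every request now carries an `http_worker_ipc` (or a per-item `http_worker_ipcs` list), and `SocketMapping` lazily opens a PUSH socket for an IPC name the first time something is routed to it. A rough sketch of that lazy connect-and-send pattern using plain pyzmq (not sglang's `get_zmq_socket` helper; class and endpoint names here are illustrative):

import zmq

class LazyPushMapping:
    """Lazily create one PUSH socket per IPC endpoint and route objects to it."""

    def __init__(self):
        self._ctx = zmq.Context.instance()
        self._mapping: dict[str, zmq.Socket] = {}

    def send_output(self, ipc_name: str, output) -> None:
        if ipc_name is None:
            return  # nothing to route
        if ipc_name not in self._mapping:
            sock = self._ctx.socket(zmq.PUSH)
            sock.connect(ipc_name)  # e.g. an "ipc://..." endpoint of an HTTP worker
            self._mapping[ipc_name] = sock
        self._mapping[ipc_name].send_pyobj(output)

# Usage sketch: route the i-th entry of a batched output back to its worker.
# mapping = LazyPushMapping()
# for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
#     mapping.send_output(ipc_name, _handle_output_by_index(output, i))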
sglang/srt/managers/multimodal_processor.py
@@ -12,8 +12,7 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}
 
 
-def import_processors():
-    package_name = "sglang.srt.multimodal.processors"
+def import_processors(package_name: str):
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
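
`import_processors` now receives the processor package as an argument instead of hardcoding `sglang.srt.multimodal.processors`, so other processor packages can be loaded through the same code path. A hedged sketch of the underlying dynamic-import pattern (the helper name below is illustrative, not the sglang function):

import importlib
import pkgutil

def import_submodules(package_name: str) -> list:
    """Import every non-package submodule under `package_name` and return the modules."""
    package = importlib.import_module(package_name)
    modules = []
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            modules.append(importlib.import_module(name))
    return modules

# The previously hardcoded default becomes an explicit call site:
# import_processors("sglang.srt.multimodal.processors")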
sglang/srt/managers/overlap_utils.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.utils import get_compiler_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+    from sglang.srt.managers.scheduler import GenerationBatchResult
+    from sglang.srt.speculative.eagle_info import EagleDraftInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
+@dataclass
+class FutureIndices:
+    indices: torch.Tensor
+    interval: Optional[slice] = None
+
+
+class FutureMap:
+    def __init__(
+        self,
+        max_running_requests: int,
+        device: torch.device,
+        spec_algo: Optional[SpeculativeAlgorithm] = None,
+    ):
+        self.future_ct = 0
+        # A factor of 3 is used to avoid collision in the circular buffer.
+        self.future_limit = max_running_requests * 3
+        # A factor of 5 is used to ensure the buffer is large enough.
+        self.future_buffer_len = max_running_requests * 5
+        self.device = device
+        self.spec_algo = spec_algo
+        self.buf_initialized = False
+
+        if self.spec_algo.is_none():
+            self.token_ids_buf = torch.empty(
+                (self.future_buffer_len,), dtype=torch.int64, device=self.device
+            )
+
+    def _lazy_init_buf(self, draft_input: EagleDraftInput):
+        if self.buf_initialized or not self.spec_algo.is_eagle():
+            return
+
+        self.buf_initialized = True
+
+        # get the template for each tensor
+        topk_p0 = draft_input.topk_p[0]
+        topk_index0 = draft_input.topk_index[0]
+        hidden_states0 = draft_input.hidden_states[0]
+        verified_id0 = draft_input.verified_id[0]
+        new_seq_lens0 = draft_input.new_seq_lens[0]
+
+        self.topk_p_buf = torch.empty(
+            (self.future_buffer_len, *topk_p0.shape),
+            dtype=topk_p0.dtype,
+            device=self.device,
+        )
+        self.topk_index_buf = torch.empty(
+            (self.future_buffer_len, *topk_index0.shape),
+            dtype=topk_index0.dtype,
+            device=self.device,
+        )
+        self.hidden_states_buf = torch.empty(
+            (self.future_buffer_len, *hidden_states0.shape),
+            dtype=hidden_states0.dtype,
+            device=self.device,
+        )
+        self.verified_id_buf = torch.empty(
+            (self.future_buffer_len, *verified_id0.shape),
+            dtype=verified_id0.dtype,
+            device=self.device,
+        )
+        self.new_seq_lens_buf = torch.empty(
+            (self.future_buffer_len, *new_seq_lens0.shape),
+            dtype=new_seq_lens0.dtype,
+            device=self.device,
+        )
+
+    def alloc_future_indices(self, bs: int) -> FutureIndices:
+        """Update the circular buffer pointer and allocate future indices."""
+        cur_future_ct = self.future_ct
+        self.future_ct = (cur_future_ct + bs) % self.future_limit
+        start = cur_future_ct + 1
+        end = cur_future_ct + 1 + bs
+        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
+        return FutureIndices(indices=indices, interval=slice(start, end))
+
+    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
+        if self.spec_algo.is_eagle():
+            # TODO(lsyin): write future indices into spec_info.future_indices
+            draft_input: EagleDraftInput = model_worker_batch.spec_info
+            if draft_input is None:
+                # FIXME(lsyin): No future exists, only for prefill batch, not compatible with mixed mode
+                return
+            indices = draft_input.future_indices.indices
+            draft_input.topk_p = self.topk_p_buf[indices]
+            draft_input.topk_index = self.topk_index_buf[indices]
+            draft_input.hidden_states = self.hidden_states_buf[indices]
+            draft_input.verified_id = self.verified_id_buf[indices]
+            draft_input.new_seq_lens = self.new_seq_lens_buf[indices]
+        else:
+            _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
+
+    def store_to_map(
+        self, future_indices: FutureIndices, batch_result: GenerationBatchResult
+    ):
+        intv = future_indices.interval
+        if self.spec_algo.is_eagle():
+            draft_input: EagleDraftInput = batch_result.next_draft_input
+            self._lazy_init_buf(draft_input)
+            self.topk_p_buf[intv] = draft_input.topk_p
+            self.topk_index_buf[intv] = draft_input.topk_index
+            self.hidden_states_buf[intv] = draft_input.hidden_states
+            self.verified_id_buf[intv] = draft_input.verified_id
+            self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
+        else:
+            self.token_ids_buf[intv] = batch_result.next_token_ids
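
The new `FutureMap` backs overlap scheduling: a batch whose sampling has not finished yet is referenced by the next batch through negative placeholder token IDs that index a circular `token_ids_buf` (or, for EAGLE, through per-field buffers filled by `store_to_map` and read back by `resolve_future`). A standalone sketch of the non-speculative round trip (buffer size and indices are illustrative):

import torch

buffer_len = 16
token_ids_buf = torch.zeros(buffer_len, dtype=torch.int64)

# Allocate future slots 1..3 for a batch of size 3 (slot 0 stays unused).
indices = torch.arange(1, 4)
# The next batch's input_ids reference those slots as negative placeholders.
input_ids = torch.tensor([-1, -2, -3], dtype=torch.int64)

# Later, the sampled next-token IDs for the earlier batch are stored.
token_ids_buf[1:4] = torch.tensor([101, 102, 103])

# Resolve: replace negative placeholders with the stored token IDs in place.
input_ids[:] = torch.where(
    input_ids < 0,
    token_ids_buf[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids)  # tensor([101, 102, 103])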