sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -36,10 +36,10 @@ else:
36
36
  Image = Any
37
37
 
38
38
 
39
- # Parameters for a session
40
39
  @dataclass
41
40
  class BaseReq(ABC):
42
41
  rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
42
+ http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
43
43
 
44
44
  def regenerate_rid(self):
45
45
  """Generate a new request ID and return it."""
@@ -53,6 +53,7 @@ class BaseReq(ABC):
53
53
  @dataclass
54
54
  class BaseBatchReq(ABC):
55
55
  rids: Optional[List[str]] = field(default=None, kw_only=True)
56
+ http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
56
57
 
57
58
  def regenerate_rids(self):
58
59
  """Generate new request IDs and return them."""
@@ -60,9 +61,11 @@ class BaseBatchReq(ABC):
60
61
  return self.rids
61
62
 
62
63
 
64
+ # Parameters for a session
63
65
  @dataclass
64
66
  class SessionParams:
65
67
  id: Optional[str] = None
68
+ rid: Optional[str] = None
66
69
  offset: Optional[int] = None
67
70
  replace: Optional[bool] = None
68
71
  drop_previous_output: Optional[bool] = None
@@ -169,6 +172,9 @@ class GenerateReqInput(BaseReq):
169
172
  # (Internal) Whether to return bytes for image generation
170
173
  return_bytes: bool = False
171
174
 
175
+ # Whether to return entropy
176
+ return_entropy: bool = False
177
+
172
178
  def contains_mm_input(self) -> bool:
173
179
  return (
174
180
  has_valid_data(self.image_data)
@@ -567,6 +573,8 @@ class GenerateReqInput(BaseReq):
567
573
  no_logs=self.no_logs,
568
574
  custom_labels=self.custom_labels,
569
575
  return_bytes=self.return_bytes,
576
+ return_entropy=self.return_entropy,
577
+ http_worker_ipc=self.http_worker_ipc,
570
578
  )
571
579
 
572
580
 
@@ -632,6 +640,9 @@ class TokenizedGenerateReqInput(BaseReq):
632
640
  # (Internal) Whether to return bytes for image generation
633
641
  return_bytes: bool = False
634
642
 
643
+ # Whether to return entropy
644
+ return_entropy: bool = False
645
+
635
646
 
636
647
  @dataclass
637
648
  class BatchTokenizedGenerateReqInput(BaseBatchReq):
@@ -749,6 +760,7 @@ class EmbeddingReqInput(BaseReq):
749
760
  sampling_params=self.sampling_params[i],
750
761
  rid=self.rid[i],
751
762
  is_cross_encoder_request=True,
763
+ http_worker_ipc=self.http_worker_ipc,
752
764
  )
753
765
 
754
766
  return EmbeddingReqInput(
@@ -759,6 +771,7 @@ class EmbeddingReqInput(BaseReq):
759
771
  video_data=self.video_data[i] if self.video_data is not None else None,
760
772
  sampling_params=self.sampling_params[i],
761
773
  rid=self.rid[i],
774
+ http_worker_ipc=self.http_worker_ipc,
762
775
  )
763
776
 
764
777
 
@@ -815,6 +828,7 @@ class BatchTokenIDOutput(BaseBatchReq):
815
828
  completion_tokens: List[int]
816
829
  cached_tokens: List[int]
817
830
  spec_verify_ct: List[int]
831
+ spec_accepted_tokens: List[int]
818
832
 
819
833
  # Logprobs
820
834
  input_token_logprobs_val: List[float]
@@ -829,6 +843,7 @@ class BatchTokenIDOutput(BaseBatchReq):
829
843
  input_token_ids_logprobs_idx: List[List]
830
844
  output_token_ids_logprobs_val: List[List]
831
845
  output_token_ids_logprobs_idx: List[List]
846
+ output_token_entropy_val: List[float]
832
847
 
833
848
  # Hidden states
834
849
  output_hidden_states: List[List[float]]
@@ -839,6 +854,9 @@ class BatchTokenIDOutput(BaseBatchReq):
839
854
  placeholder_tokens_idx: List[Optional[List[int]]]
840
855
  placeholder_tokens_val: List[Optional[List[int]]]
841
856
 
857
+ # The trainer step id. Used to know which step's weights are used for sampling.
858
+ token_steps: List[List[int]] = None
859
+
842
860
 
843
861
  @dataclass
844
862
  class BatchMultimodalDecodeReq(BaseBatchReq):
@@ -860,11 +878,16 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
860
878
  completion_tokens: List[int]
861
879
  cached_tokens: List[int]
862
880
 
863
- # Placeholder token info
881
+ # The information of placeholder tokens (e.g., image token)
882
+ # idx is the index of the token in the prompt after expansion.
883
+ # val is the length of padded tokens after expansion.
864
884
  placeholder_tokens_idx: List[Optional[List[int]]]
865
885
  placeholder_tokens_val: List[Optional[List[int]]]
866
886
 
867
- return_bytes: bool = False
887
+ return_bytes: List[bool]
888
+
889
+ # The trainer step id. Used to know which step's weights are used for sampling.
890
+ token_steps: List[List[int]] = None
868
891
 
869
892
 
870
893
  @dataclass
@@ -881,6 +904,7 @@ class BatchStrOutput(BaseBatchReq):
881
904
  completion_tokens: List[int]
882
905
  cached_tokens: List[int]
883
906
  spec_verify_ct: List[int]
907
+ spec_accepted_tokens: List[int]
884
908
 
885
909
  # Logprobs
886
910
  input_token_logprobs_val: List[float]
@@ -895,13 +919,20 @@ class BatchStrOutput(BaseBatchReq):
895
919
  input_token_ids_logprobs_idx: List[List]
896
920
  output_token_ids_logprobs_val: List[List]
897
921
  output_token_ids_logprobs_idx: List[List]
922
+ output_token_entropy_val: List[float]
898
923
 
899
924
  # Hidden states
900
925
  output_hidden_states: List[List[float]]
901
926
 
927
+ # The information of placeholder tokens (e.g., image token)
928
+ # idx is the index of the token in the prompt after expansion.
929
+ # val is the length of padded tokens after expansion.
902
930
  placeholder_tokens_idx: List[Optional[List[int]]]
903
931
  placeholder_tokens_val: List[Optional[List[int]]]
904
932
 
933
+ # The trainer step id. Used to know which step's weights are used for sampling.
934
+ token_steps: List[List[int]] = None
935
+
905
936
 
906
937
  @dataclass
907
938
  class BatchMultimodalOutput(BaseBatchReq):
@@ -933,7 +964,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
933
964
  # The finish reason
934
965
  finished_reasons: List[BaseFinishReason]
935
966
  # The output embedding
936
- embeddings: List[List[float]]
967
+ embeddings: Union[List[List[float]], List[Dict[int, float]]]
937
968
  # Token counts
938
969
  prompt_tokens: List[int]
939
970
  cached_tokens: List[int]
@@ -978,6 +1009,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
978
1009
  torch_empty_cache: bool = False
979
1010
  # Whether to keep the scheduler paused after weight update
980
1011
  keep_pause: bool = False
1012
+ # The trainer step id. Used to know which step's weights are used for sampling.
1013
+ token_step: int = 0
981
1014
 
982
1015
 
983
1016
  @dataclass
@@ -1050,6 +1083,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
1050
1083
  backend: str = "nccl"
1051
1084
 
1052
1085
 
1086
+ # Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
1087
+ # are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
1088
+ @dataclass
1089
+ class UpdateWeightsFromIPCReqInput(BaseReq):
1090
+ # ZMQ socket paths for each device UUID
1091
+ zmq_handles: Dict[str, str]
1092
+ # Whether to flush cache after weight update
1093
+ flush_cache: bool = True
1094
+ # Optional: Update weight version along with weights
1095
+ weight_version: Optional[str] = None
1096
+
1097
+
1098
+ @dataclass
1099
+ class UpdateWeightsFromIPCReqOutput(BaseReq):
1100
+ success: bool
1101
+ message: str
1102
+
1103
+
1053
1104
  @dataclass
1054
1105
  class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
1055
1106
  success: bool
@@ -1206,6 +1257,8 @@ class ProfileReqInput(BaseReq):
1206
1257
  profile_by_stage: bool = False
1207
1258
  with_stack: Optional[bool] = None
1208
1259
  record_shapes: Optional[bool] = None
1260
+ # Merge profiles from all ranks into a single trace
1261
+ merge_profiles: bool = False
1209
1262
 
1210
1263
 
1211
1264
  class ProfileReqType(Enum):
@@ -1224,6 +1277,8 @@ class ProfileReq(BaseReq):
1224
1277
  with_stack: Optional[bool] = None
1225
1278
  record_shapes: Optional[bool] = None
1226
1279
  profile_id: Optional[str] = None
1280
+ # Merge profiles from all ranks into a single trace
1281
+ merge_profiles: bool = False
1227
1282
 
1228
1283
 
1229
1284
  @dataclass
@@ -1375,18 +1430,6 @@ class LoRAUpdateOutput(BaseReq):
1375
1430
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
1376
1431
 
1377
1432
 
1378
- @dataclass
1379
- class MultiTokenizerRegisterReq(BaseBatchReq):
1380
- ipc_name: Optional[str] = None
1381
-
1382
-
1383
- @dataclass
1384
- class MultiTokenizerWrapper:
1385
- # FIXME(lsyin): remove this
1386
- worker_id: int
1387
- obj: Optional[Any] = None
1388
-
1389
-
1390
1433
  class BlockReqType(Enum):
1391
1434
  BLOCK = 1
1392
1435
  UNBLOCK = 2
@@ -1415,6 +1458,16 @@ class WatchLoadUpdateReq(BaseReq):
1415
1458
  loads: List[GetLoadReqOutput]
1416
1459
 
1417
1460
 
1461
+ @dataclass
1462
+ class LazyDumpTensorsReqInput(BaseReq):
1463
+ pass
1464
+
1465
+
1466
+ @dataclass
1467
+ class LazyDumpTensorsReqOutput(BaseReq):
1468
+ success: bool
1469
+
1470
+
1418
1471
  def _check_all_req_types():
1419
1472
  """A helper function to check all request types are defined in this file."""
1420
1473
  import inspect
@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
16
16
  Modality,
17
17
  MultimodalDataItem,
18
18
  MultimodalInputs,
19
- global_server_args_dict,
20
19
  )
21
20
  from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
22
21
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
22
+ from sglang.srt.server_args import get_global_server_args
23
23
  from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
24
24
  from sglang.utils import logger
25
25
 
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
280
280
  input_ids_tensor[input_ids_tensor == token_id] = pad_value
281
281
 
282
282
  ret_input_ids = input_ids_tensor.tolist()
283
-
284
283
  return ret_input_ids
285
284
 
286
285
 
@@ -428,7 +427,7 @@ def _adjust_embedding_length(
428
427
  f"tokens from multimodal embeddings."
429
428
  )
430
429
  if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
431
- chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
430
+ chunked_prefill_size = get_global_server_args().chunked_prefill_size
432
431
  if chunked_prefill_size != -1:
433
432
  logger.warning(
434
433
  "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
@@ -507,7 +506,7 @@ def embed_mm_inputs(
507
506
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
508
507
  ] = None,
509
508
  placeholder_tokens: dict[Modality, List[int]] = None,
510
- use_deepstack: bool = False,
509
+ use_deepstack: Dict[Modality, bool] = {},
511
510
  ) -> Optional[torch.Tensor]:
512
511
  """
513
512
  Embed multimodal inputs and integrate them with text token embeddings.
@@ -533,7 +532,9 @@ def embed_mm_inputs(
533
532
  for mm_inputs in mm_inputs_list:
534
533
  item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
535
534
 
536
- embeddings, masks, deepstack_embeddings = [], [], []
535
+ # deepstack_embeddings: per-modality
536
+ modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
537
+
537
538
  # 2. Get multimodal embedding separately
538
539
  # Try get mm embedding if any
539
540
  for modality in Modality.all():
@@ -549,7 +550,8 @@ def embed_mm_inputs(
549
550
  # "image", "video", etc
550
551
  modality_id = modality.name.lower()
551
552
  embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
552
- if len(items) != 0 and embedder is not None:
553
+ if len(items) != 0:
554
+ assert embedder is not None, f"no embedding method found for {modality}"
553
555
  placeholder_tensor = torch.as_tensor(
554
556
  [item.pad_value for item in items],
555
557
  device=input_ids.device,
@@ -580,11 +582,12 @@ def embed_mm_inputs(
580
582
  items_offset_list=items_offsets,
581
583
  )
582
584
 
583
- if use_deepstack and embedding is not None:
585
+ if use_deepstack.get(modality, None) and embedding is not None:
584
586
  embedding, deepstack_embedding = (
585
587
  multimodal_model.separate_deepstack_embeds(embedding)
586
588
  )
587
589
  deepstack_embeddings += [deepstack_embedding]
590
+ modalities += [modality]
588
591
  embeddings += [embedding]
589
592
  masks += [mask]
590
593
 
@@ -597,17 +600,14 @@ def embed_mm_inputs(
597
600
  input_ids.clamp_(min=0, max=vocab_size - 1)
598
601
  inputs_embeds = input_embedding(input_ids)
599
602
 
600
- # 4. scatter embeddings into input embedding
601
-
602
603
  # deepstack embedding
603
604
  if use_deepstack:
604
- num_deepstack_embeddings = (
605
- len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
606
- )
605
+ num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
606
+
607
607
  deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
608
608
  inputs_embeds.shape[-1] * num_deepstack_embeddings,
609
609
  )
610
-
610
+ # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
611
611
  input_deepstack_embeds = torch.zeros(
612
612
  deepstack_embedding_shape,
613
613
  device=inputs_embeds.device,
@@ -616,14 +616,16 @@ def embed_mm_inputs(
616
616
 
617
617
  other_info["input_deepstack_embeds"] = input_deepstack_embeds
618
618
 
619
- for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
619
+ # 4. scatter embeddings into input embedding
620
+ for i, modality, embedding, mask in zip(
621
+ range(len(embeddings)), modalities, embeddings, masks
622
+ ):
620
623
  if embedding is None or mask is None:
621
624
  continue
622
625
  # in-place update
623
626
  indices = torch.where(mask.squeeze(dim=-1))[0]
624
627
  inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
625
-
626
- if use_deepstack:
628
+ if use_deepstack.get(modality, None):
627
629
  input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
628
630
  inputs_embeds.device, inputs_embeds.dtype
629
631
  )
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
640
642
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
641
643
  ] = None,
642
644
  placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
643
- use_deepstack: bool = False,
645
+ use_deepstack: Dict[Modality, bool] = {},
644
646
  **kwargs,
645
647
  ) -> torch.Tensor:
646
648
  """
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
652
654
  language_model: Base language model to use
653
655
  data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
654
656
  placeholder_tokens: Token IDs for multimodal placeholders
655
- use_deepstack: Whether to use deepstack embeddings
657
+ use_deepstack: Whether to use deepstack embeddings for each modality, default False
656
658
  **kwargs: Additional arguments passed to language model
657
659
 
658
660
  Returns:
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  # Copyright 2023-2024 SGLang Team
2
4
  # Licensed under the Apache License, Version 2.0 (the "License");
3
5
  # you may not use this file except in compliance with the License.
@@ -11,7 +13,12 @@
11
13
  # See the License for the specific language governing permissions and
12
14
  # limitations under the License.
13
15
  # ==============================================================================
14
- """Mixin class and utils for multi-http-worker mode"""
16
+
17
+ """
18
+ Mixin classes and utils for multi-http-worker mode
19
+ This file uses multiple processes to handle requests and tokenization, reducing the overhead of python and http server.
20
+ """
21
+
15
22
  import asyncio
16
23
  import logging
17
24
  import multiprocessing as multiprocessing
@@ -21,7 +28,7 @@ import sys
21
28
  import threading
22
29
  from functools import partialmethod
23
30
  from multiprocessing import shared_memory
24
- from typing import Any, Dict
31
+ from typing import TYPE_CHECKING, Any, Dict, Union
25
32
 
26
33
  import setproctitle
27
34
  import zmq
@@ -30,12 +37,12 @@ import zmq.asyncio
30
37
  from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
31
38
  from sglang.srt.managers.disagg_service import start_disagg_service
32
39
  from sglang.srt.managers.io_struct import (
40
+ BaseBatchReq,
41
+ BaseReq,
33
42
  BatchEmbeddingOutput,
34
43
  BatchMultimodalOutput,
35
44
  BatchStrOutput,
36
45
  BatchTokenIDOutput,
37
- MultiTokenizerRegisterReq,
38
- MultiTokenizerWrapper,
39
46
  )
40
47
  from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
41
48
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -43,6 +50,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
43
50
  from sglang.srt.utils import get_zmq_socket, kill_process_tree
44
51
  from sglang.utils import get_exception_traceback
45
52
 
53
+ if TYPE_CHECKING:
54
+ from sglang.srt.managers.detokenizer_manager import DetokenizerManager
55
+
46
56
  logger = logging.getLogger(__name__)
47
57
 
48
58
 
@@ -56,29 +66,24 @@ class SocketMapping:
56
66
  socket.close()
57
67
  self._mapping.clear()
58
68
 
59
- def register_ipc_mapping(
60
- self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
61
- ):
69
+ def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
62
70
  type_str = "tokenizer" if is_tokenizer else "detokenizer"
63
- if worker_id in self._mapping:
64
- logger.warning(
65
- f"{type_str} already registered with worker {worker_id}, skipping..."
66
- )
71
+ if ipc_name in self._mapping:
72
+ logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
67
73
  return
68
- logger.info(
69
- f"{type_str} not registered with worker {worker_id}, registering..."
70
- )
71
- socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
72
- self._mapping[worker_id] = socket
73
- self._mapping[worker_id].send_pyobj(recv_obj)
74
-
75
- def send_output(self, worker_id: str, output: Any):
76
- if worker_id not in self._mapping:
77
- logger.error(
78
- f"worker ID {worker_id} not registered. Check if the server Process is alive"
79
- )
74
+ logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
75
+ socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
76
+ self._mapping[ipc_name] = socket
77
+
78
+ def send_output(self, ipc_name: str, output: Any):
79
+ if ipc_name is None:
80
+ # Some unhandled cases
81
+ logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
80
82
  return
81
- self._mapping[worker_id].send_pyobj(output)
83
+
84
+ if ipc_name not in self._mapping:
85
+ self._register_ipc_mapping(ipc_name, is_tokenizer=False)
86
+ self._mapping[ipc_name].send_pyobj(output)
82
87
 
83
88
 
84
89
  def _handle_output_by_index(output, i):
@@ -190,6 +195,11 @@ def _handle_output_by_index(output, i):
190
195
  if output.output_token_ids_logprobs_idx
191
196
  else None
192
197
  ),
198
+ output_token_entropy_val=(
199
+ [output.output_token_entropy_val[i]]
200
+ if output.output_token_entropy_val
201
+ else None
202
+ ),
193
203
  output_hidden_states=(
194
204
  [output.output_hidden_states[i]]
195
205
  if output.output_hidden_states
@@ -197,6 +207,7 @@ def _handle_output_by_index(output, i):
197
207
  ),
198
208
  placeholder_tokens_idx=None,
199
209
  placeholder_tokens_val=None,
210
+ token_steps=([output.token_steps[i]] if output.token_steps else None),
200
211
  )
201
212
  elif isinstance(output, BatchEmbeddingOutput):
202
213
  new_output = BatchEmbeddingOutput(
@@ -246,6 +257,11 @@ def _handle_output_by_index(output, i):
246
257
  spec_verify_ct=(
247
258
  [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
248
259
  ),
260
+ spec_accepted_tokens=(
261
+ [output.spec_accepted_tokens[i]]
262
+ if len(output.spec_accepted_tokens) > i
263
+ else None
264
+ ),
249
265
  input_token_logprobs_val=(
250
266
  [output.input_token_logprobs_val[i]]
251
267
  if output.input_token_logprobs_val
@@ -306,6 +322,11 @@ def _handle_output_by_index(output, i):
306
322
  if output.output_token_ids_logprobs_idx
307
323
  else None
308
324
  ),
325
+ output_token_entropy_val=(
326
+ [output.output_token_entropy_val[i]]
327
+ if output.output_token_entropy_val
328
+ else None
329
+ ),
309
330
  output_hidden_states=(
310
331
  [output.output_hidden_states[i]]
311
332
  if output.output_hidden_states
@@ -313,6 +334,7 @@ def _handle_output_by_index(output, i):
313
334
  ),
314
335
  placeholder_tokens_idx=None,
315
336
  placeholder_tokens_val=None,
337
+ token_steps=([output.token_steps[i]] if output.token_steps else None),
316
338
  )
317
339
  elif isinstance(output, BatchMultimodalOutput):
318
340
  new_output = BatchMultimodalOutput(
@@ -345,20 +367,11 @@ def _handle_output_by_index(output, i):
345
367
  class MultiHttpWorkerDetokenizerMixin:
346
368
  """Mixin class for DetokenizerManager"""
347
369
 
348
- def get_worker_ids_from_req_rids(self, rids):
349
- if isinstance(rids, list):
350
- worker_ids = [int(rid.split("_")[0]) for rid in rids]
351
- elif isinstance(rids, str):
352
- worker_ids = [int(rids.split("_")[0])]
353
- else:
354
- worker_ids = []
355
- return worker_ids
356
-
357
- def maybe_clear_socket_mapping(self):
370
+ def maybe_clear_socket_mapping(self: DetokenizerManager):
358
371
  if hasattr(self, "socket_mapping"):
359
372
  self.socket_mapping.clear_all_sockets()
360
373
 
361
- def multi_http_worker_event_loop(self):
374
+ def multi_http_worker_event_loop(self: DetokenizerManager):
362
375
  """The event loop that handles requests, for multi multi-http-worker mode"""
363
376
  self.socket_mapping = SocketMapping()
364
377
  while True:
@@ -366,23 +379,15 @@ class MultiHttpWorkerDetokenizerMixin:
366
379
  output = self._request_dispatcher(recv_obj)
367
380
  if output is None:
368
381
  continue
369
- # Extract worker_id from rid
370
- if isinstance(recv_obj.rids, list):
371
- worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
372
- else:
373
- raise RuntimeError(
374
- f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
375
- )
382
+
383
+ assert isinstance(
384
+ recv_obj, BaseBatchReq
385
+ ), "for multi-http-worker, recv_obj must be BaseBatchReq"
376
386
 
377
387
  # Send data using the corresponding socket
378
- for i, worker_id in enumerate(worker_ids):
379
- if isinstance(recv_obj, MultiTokenizerRegisterReq):
380
- self.socket_mapping.register_ipc_mapping(
381
- recv_obj, worker_id, is_tokenizer=False
382
- )
383
- else:
384
- new_output = _handle_output_by_index(output, i)
385
- self.socket_mapping.send_output(worker_id, new_output)
388
+ for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
389
+ new_output = _handle_output_by_index(output, i)
390
+ self.socket_mapping.send_output(ipc_name, new_output)
386
391
 
387
392
 
388
393
  class MultiTokenizerRouter:
@@ -432,26 +437,17 @@ class MultiTokenizerRouter:
432
437
  await self._distribute_result_to_workers(recv_obj)
433
438
 
434
439
  async def _distribute_result_to_workers(self, recv_obj):
435
- """Distribute result to corresponding workers based on rid"""
436
- if isinstance(recv_obj, MultiTokenizerWrapper):
437
- worker_ids = [recv_obj.worker_id]
438
- recv_obj = recv_obj.obj
440
+ # Distribute result to each worker
441
+ if isinstance(recv_obj, BaseReq):
442
+ ipc_names = [recv_obj.http_worker_ipc]
443
+ elif isinstance(recv_obj, BaseBatchReq):
444
+ ipc_names = recv_obj.http_worker_ipcs
439
445
  else:
440
- worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
446
+ raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
441
447
 
442
- if len(worker_ids) == 0:
443
- logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
444
- return
445
-
446
- # Distribute result to each worker
447
- for i, worker_id in enumerate(worker_ids):
448
- if isinstance(recv_obj, MultiTokenizerRegisterReq):
449
- self.socket_mapping.register_ipc_mapping(
450
- recv_obj, worker_id, is_tokenizer=True
451
- )
452
- else:
453
- new_recv_obj = _handle_output_by_index(recv_obj, i)
454
- self.socket_mapping.send_output(worker_id, new_recv_obj)
448
+ for i, ipc_name in enumerate(ipc_names):
449
+ new_recv_obj = _handle_output_by_index(recv_obj, i)
450
+ self.socket_mapping.send_output(ipc_name, new_recv_obj)
455
451
 
456
452
 
457
453
  class TokenizerWorker(TokenizerManager):
@@ -483,21 +479,15 @@ class TokenizerWorker(TokenizerManager):
483
479
  self.register_multi_tokenizer_communicator = _Communicator(
484
480
  self.send_to_scheduler, 2
485
481
  )
486
- self._result_dispatcher._mapping.append(
487
- (
488
- MultiTokenizerRegisterReq,
489
- self.register_multi_tokenizer_communicator.handle_recv,
490
- )
491
- )
492
482
 
493
- async def register_to_main_tokenizer_manager(self):
494
- """Register this worker to the main TokenizerManager"""
495
- # create a handle loop to receive messages from the main TokenizerManager
496
- self.auto_create_handle_loop()
497
- req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
498
- req.ipc_name = self.tokenizer_ipc_name
499
- _Communicator.enable_multi_tokenizer = True
500
- await self.register_multi_tokenizer_communicator(req)
483
+ def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
484
+
485
+ if isinstance(req, BaseReq):
486
+ req.http_worker_ipc = self.tokenizer_ipc_name
487
+ elif isinstance(req, BaseBatchReq):
488
+ req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
489
+ else:
490
+ raise ValueError(f"Unknown req type: {type(req)}")
501
491
 
502
492
 
503
493
  async def print_exception_wrapper(func):
@@ -581,3 +571,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
581
571
  logger.warning(
582
572
  "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
583
573
  )
574
+
575
+
576
+ class SenderWrapper:
577
+ def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
578
+ self.port_args = port_args
579
+ self.send_to_scheduler = send_to_scheduler
580
+
581
+ def send_pyobj(self, obj):
582
+ if isinstance(obj, BaseReq):
583
+ obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
584
+ self.send_to_scheduler.send_pyobj(obj)