sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -36,10 +36,10 @@ else:
36
36
  Image = Any
37
37
 
38
38
 
39
- # Parameters for a session
40
39
  @dataclass
41
40
  class BaseReq(ABC):
42
41
  rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
42
+ http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
43
43
 
44
44
  def regenerate_rid(self):
45
45
  """Generate a new request ID and return it."""
@@ -53,6 +53,7 @@ class BaseReq(ABC):
53
53
  @dataclass
54
54
  class BaseBatchReq(ABC):
55
55
  rids: Optional[List[str]] = field(default=None, kw_only=True)
56
+ http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
56
57
 
57
58
  def regenerate_rids(self):
58
59
  """Generate new request IDs and return them."""
@@ -60,9 +61,11 @@ class BaseBatchReq(ABC):
60
61
  return self.rids
61
62
 
62
63
 
64
+ # Parameters for a session
63
65
  @dataclass
64
66
  class SessionParams:
65
67
  id: Optional[str] = None
68
+ rid: Optional[str] = None
66
69
  offset: Optional[int] = None
67
70
  replace: Optional[bool] = None
68
71
  drop_previous_output: Optional[bool] = None
@@ -169,6 +172,9 @@ class GenerateReqInput(BaseReq):
169
172
  # (Internal) Whether to return bytes for image generation
170
173
  return_bytes: bool = False
171
174
 
175
+ # Whether to return entropy
176
+ return_entropy: bool = False
177
+
172
178
  def contains_mm_input(self) -> bool:
173
179
  return (
174
180
  has_valid_data(self.image_data)
@@ -567,6 +573,7 @@ class GenerateReqInput(BaseReq):
567
573
  no_logs=self.no_logs,
568
574
  custom_labels=self.custom_labels,
569
575
  return_bytes=self.return_bytes,
576
+ return_entropy=self.return_entropy,
570
577
  )
571
578
 
572
579
 
@@ -632,6 +639,9 @@ class TokenizedGenerateReqInput(BaseReq):
632
639
  # (Internal) Whether to return bytes for image generation
633
640
  return_bytes: bool = False
634
641
 
642
+ # Whether to return entropy
643
+ return_entropy: bool = False
644
+
635
645
 
636
646
  @dataclass
637
647
  class BatchTokenizedGenerateReqInput(BaseBatchReq):
@@ -815,6 +825,7 @@ class BatchTokenIDOutput(BaseBatchReq):
815
825
  completion_tokens: List[int]
816
826
  cached_tokens: List[int]
817
827
  spec_verify_ct: List[int]
828
+ spec_accepted_tokens: List[int]
818
829
 
819
830
  # Logprobs
820
831
  input_token_logprobs_val: List[float]
@@ -829,6 +840,7 @@ class BatchTokenIDOutput(BaseBatchReq):
829
840
  input_token_ids_logprobs_idx: List[List]
830
841
  output_token_ids_logprobs_val: List[List]
831
842
  output_token_ids_logprobs_idx: List[List]
843
+ output_token_entropy_val: List[float]
832
844
 
833
845
  # Hidden states
834
846
  output_hidden_states: List[List[float]]
@@ -839,6 +851,9 @@ class BatchTokenIDOutput(BaseBatchReq):
839
851
  placeholder_tokens_idx: List[Optional[List[int]]]
840
852
  placeholder_tokens_val: List[Optional[List[int]]]
841
853
 
854
+ # The trainer step id. Used to know which step's weights are used for sampling.
855
+ token_steps: List[List[int]] = None
856
+
842
857
 
843
858
  @dataclass
844
859
  class BatchMultimodalDecodeReq(BaseBatchReq):
@@ -860,11 +875,16 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
860
875
  completion_tokens: List[int]
861
876
  cached_tokens: List[int]
862
877
 
863
- # Placeholder token info
878
+ # The information of placeholder tokens (e.g., image token)
879
+ # idx is the index of the token in the prompt after expansion.
880
+ # val is the length of padded tokens after expansion.
864
881
  placeholder_tokens_idx: List[Optional[List[int]]]
865
882
  placeholder_tokens_val: List[Optional[List[int]]]
866
883
 
867
- return_bytes: bool = False
884
+ return_bytes: List[bool]
885
+
886
+ # The trainer step id. Used to know which step's weights are used for sampling.
887
+ token_steps: List[List[int]] = None
868
888
 
869
889
 
870
890
  @dataclass
@@ -881,6 +901,7 @@ class BatchStrOutput(BaseBatchReq):
881
901
  completion_tokens: List[int]
882
902
  cached_tokens: List[int]
883
903
  spec_verify_ct: List[int]
904
+ spec_accepted_tokens: List[int]
884
905
 
885
906
  # Logprobs
886
907
  input_token_logprobs_val: List[float]
@@ -895,13 +916,20 @@ class BatchStrOutput(BaseBatchReq):
895
916
  input_token_ids_logprobs_idx: List[List]
896
917
  output_token_ids_logprobs_val: List[List]
897
918
  output_token_ids_logprobs_idx: List[List]
919
+ output_token_entropy_val: List[float]
898
920
 
899
921
  # Hidden states
900
922
  output_hidden_states: List[List[float]]
901
923
 
924
+ # The information of placeholder tokens (e.g., image token)
925
+ # idx is the index of the token in the prompt after expansion.
926
+ # val is the length of padded tokens after expansion.
902
927
  placeholder_tokens_idx: List[Optional[List[int]]]
903
928
  placeholder_tokens_val: List[Optional[List[int]]]
904
929
 
930
+ # The trainer step id. Used to know which step's weights are used for sampling.
931
+ token_steps: List[List[int]] = None
932
+
905
933
 
906
934
  @dataclass
907
935
  class BatchMultimodalOutput(BaseBatchReq):
@@ -933,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
933
961
  # The finish reason
934
962
  finished_reasons: List[BaseFinishReason]
935
963
  # The output embedding
936
- embeddings: List[List[float]]
964
+ embeddings: Union[List[List[float]], List[Dict[int, float]]]
937
965
  # Token counts
938
966
  prompt_tokens: List[int]
939
967
  cached_tokens: List[int]
@@ -978,6 +1006,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
978
1006
  torch_empty_cache: bool = False
979
1007
  # Whether to keep the scheduler paused after weight update
980
1008
  keep_pause: bool = False
1009
+ # The trainer step id. Used to know which step's weights are used for sampling.
1010
+ token_step: int = 0
981
1011
 
982
1012
 
983
1013
  @dataclass
@@ -1050,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
1050
1080
  backend: str = "nccl"
1051
1081
 
1052
1082
 
1083
+ # Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
1084
+ # are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
1085
+ @dataclass
1086
+ class UpdateWeightsFromIPCReqInput(BaseReq):
1087
+ # ZMQ socket paths for each device UUID
1088
+ zmq_handles: Dict[str, str]
1089
+ # Whether to flush cache after weight update
1090
+ flush_cache: bool = True
1091
+ # Optional: Update weight version along with weights
1092
+ weight_version: Optional[str] = None
1093
+
1094
+
1095
+ @dataclass
1096
+ class UpdateWeightsFromIPCReqOutput(BaseReq):
1097
+ success: bool
1098
+ message: str
1099
+
1100
+
1053
1101
  @dataclass
1054
1102
  class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
1055
1103
  success: bool
@@ -1206,6 +1254,8 @@ class ProfileReqInput(BaseReq):
1206
1254
  profile_by_stage: bool = False
1207
1255
  with_stack: Optional[bool] = None
1208
1256
  record_shapes: Optional[bool] = None
1257
+ # Merge profiles from all ranks into a single trace
1258
+ merge_profiles: bool = False
1209
1259
 
1210
1260
 
1211
1261
  class ProfileReqType(Enum):
@@ -1224,6 +1274,8 @@ class ProfileReq(BaseReq):
1224
1274
  with_stack: Optional[bool] = None
1225
1275
  record_shapes: Optional[bool] = None
1226
1276
  profile_id: Optional[str] = None
1277
+ # Merge profiles from all ranks into a single trace
1278
+ merge_profiles: bool = False
1227
1279
 
1228
1280
 
1229
1281
  @dataclass
@@ -1375,18 +1427,6 @@ class LoRAUpdateOutput(BaseReq):
1375
1427
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
1376
1428
 
1377
1429
 
1378
- @dataclass
1379
- class MultiTokenizerRegisterReq(BaseBatchReq):
1380
- ipc_name: Optional[str] = None
1381
-
1382
-
1383
- @dataclass
1384
- class MultiTokenizerWrapper:
1385
- # FIXME(lsyin): remove this
1386
- worker_id: int
1387
- obj: Optional[Any] = None
1388
-
1389
-
1390
1430
  class BlockReqType(Enum):
1391
1431
  BLOCK = 1
1392
1432
  UNBLOCK = 2
@@ -1415,6 +1455,16 @@ class WatchLoadUpdateReq(BaseReq):
1415
1455
  loads: List[GetLoadReqOutput]
1416
1456
 
1417
1457
 
1458
+ @dataclass
1459
+ class LazyDumpTensorsReqInput(BaseReq):
1460
+ pass
1461
+
1462
+
1463
+ @dataclass
1464
+ class LazyDumpTensorsReqOutput(BaseReq):
1465
+ success: bool
1466
+
1467
+
1418
1468
  def _check_all_req_types():
1419
1469
  """A helper function to check all request types are defined in this file."""
1420
1470
  import inspect
@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
16
16
  Modality,
17
17
  MultimodalDataItem,
18
18
  MultimodalInputs,
19
- global_server_args_dict,
20
19
  )
21
20
  from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
22
21
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
22
+ from sglang.srt.server_args import get_global_server_args
23
23
  from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
24
24
  from sglang.utils import logger
25
25
 
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
280
280
  input_ids_tensor[input_ids_tensor == token_id] = pad_value
281
281
 
282
282
  ret_input_ids = input_ids_tensor.tolist()
283
-
284
283
  return ret_input_ids
285
284
 
286
285
 
@@ -428,7 +427,7 @@ def _adjust_embedding_length(
428
427
  f"tokens from multimodal embeddings."
429
428
  )
430
429
  if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
431
- chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
430
+ chunked_prefill_size = get_global_server_args().chunked_prefill_size
432
431
  if chunked_prefill_size != -1:
433
432
  logger.warning(
434
433
  "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
@@ -507,7 +506,7 @@ def embed_mm_inputs(
507
506
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
508
507
  ] = None,
509
508
  placeholder_tokens: dict[Modality, List[int]] = None,
510
- use_deepstack: bool = False,
509
+ use_deepstack: Dict[Modality, bool] = {},
511
510
  ) -> Optional[torch.Tensor]:
512
511
  """
513
512
  Embed multimodal inputs and integrate them with text token embeddings.
@@ -533,7 +532,9 @@ def embed_mm_inputs(
533
532
  for mm_inputs in mm_inputs_list:
534
533
  item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
535
534
 
536
- embeddings, masks, deepstack_embeddings = [], [], []
535
+ # deepstack_embeddings: per-modality
536
+ modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
537
+
537
538
  # 2. Get multimodal embedding separately
538
539
  # Try get mm embedding if any
539
540
  for modality in Modality.all():
@@ -549,7 +550,8 @@ def embed_mm_inputs(
549
550
  # "image", "video", etc
550
551
  modality_id = modality.name.lower()
551
552
  embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
552
- if len(items) != 0 and embedder is not None:
553
+ if len(items) != 0:
554
+ assert embedder is not None, f"no embedding method found for {modality}"
553
555
  placeholder_tensor = torch.as_tensor(
554
556
  [item.pad_value for item in items],
555
557
  device=input_ids.device,
@@ -580,11 +582,12 @@ def embed_mm_inputs(
580
582
  items_offset_list=items_offsets,
581
583
  )
582
584
 
583
- if use_deepstack and embedding is not None:
585
+ if use_deepstack.get(modality, None) and embedding is not None:
584
586
  embedding, deepstack_embedding = (
585
587
  multimodal_model.separate_deepstack_embeds(embedding)
586
588
  )
587
589
  deepstack_embeddings += [deepstack_embedding]
590
+ modalities += [modality]
588
591
  embeddings += [embedding]
589
592
  masks += [mask]
590
593
 
@@ -597,17 +600,14 @@ def embed_mm_inputs(
597
600
  input_ids.clamp_(min=0, max=vocab_size - 1)
598
601
  inputs_embeds = input_embedding(input_ids)
599
602
 
600
- # 4. scatter embeddings into input embedding
601
-
602
603
  # deepstack embedding
603
604
  if use_deepstack:
604
- num_deepstack_embeddings = (
605
- len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
606
- )
605
+ num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
606
+
607
607
  deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
608
608
  inputs_embeds.shape[-1] * num_deepstack_embeddings,
609
609
  )
610
-
610
+ # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
611
611
  input_deepstack_embeds = torch.zeros(
612
612
  deepstack_embedding_shape,
613
613
  device=inputs_embeds.device,
@@ -616,14 +616,16 @@ def embed_mm_inputs(
616
616
 
617
617
  other_info["input_deepstack_embeds"] = input_deepstack_embeds
618
618
 
619
- for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
619
+ # 4. scatter embeddings into input embedding
620
+ for i, modality, embedding, mask in zip(
621
+ range(len(embeddings)), modalities, embeddings, masks
622
+ ):
620
623
  if embedding is None or mask is None:
621
624
  continue
622
625
  # in-place update
623
626
  indices = torch.where(mask.squeeze(dim=-1))[0]
624
627
  inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
625
-
626
- if use_deepstack:
628
+ if use_deepstack.get(modality, None):
627
629
  input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
628
630
  inputs_embeds.device, inputs_embeds.dtype
629
631
  )
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
640
642
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
641
643
  ] = None,
642
644
  placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
643
- use_deepstack: bool = False,
645
+ use_deepstack: Dict[Modality, bool] = {},
644
646
  **kwargs,
645
647
  ) -> torch.Tensor:
646
648
  """
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
652
654
  language_model: Base language model to use
653
655
  data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
654
656
  placeholder_tokens: Token IDs for multimodal placeholders
655
- use_deepstack: Whether to use deepstack embeddings
657
+ use_deepstack: Whether to use deepstack embeddings for each modality, default False
656
658
  **kwargs: Additional arguments passed to language model
657
659
 
658
660
  Returns:
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  # Copyright 2023-2024 SGLang Team
2
4
  # Licensed under the Apache License, Version 2.0 (the "License");
3
5
  # you may not use this file except in compliance with the License.
@@ -21,7 +23,7 @@ import sys
21
23
  import threading
22
24
  from functools import partialmethod
23
25
  from multiprocessing import shared_memory
24
- from typing import Any, Dict
26
+ from typing import TYPE_CHECKING, Any, Dict, Union
25
27
 
26
28
  import setproctitle
27
29
  import zmq
@@ -30,12 +32,12 @@ import zmq.asyncio
30
32
  from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
31
33
  from sglang.srt.managers.disagg_service import start_disagg_service
32
34
  from sglang.srt.managers.io_struct import (
35
+ BaseBatchReq,
36
+ BaseReq,
33
37
  BatchEmbeddingOutput,
34
38
  BatchMultimodalOutput,
35
39
  BatchStrOutput,
36
40
  BatchTokenIDOutput,
37
- MultiTokenizerRegisterReq,
38
- MultiTokenizerWrapper,
39
41
  )
40
42
  from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
41
43
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
43
45
  from sglang.srt.utils import get_zmq_socket, kill_process_tree
44
46
  from sglang.utils import get_exception_traceback
45
47
 
48
+ if TYPE_CHECKING:
49
+ from sglang.srt.managers.detokenizer_manager import DetokenizerManager
50
+
46
51
  logger = logging.getLogger(__name__)
47
52
 
48
53
 
@@ -56,29 +61,24 @@ class SocketMapping:
56
61
  socket.close()
57
62
  self._mapping.clear()
58
63
 
59
- def register_ipc_mapping(
60
- self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
61
- ):
64
+ def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
62
65
  type_str = "tokenizer" if is_tokenizer else "detokenizer"
63
- if worker_id in self._mapping:
64
- logger.warning(
65
- f"{type_str} already registered with worker {worker_id}, skipping..."
66
- )
66
+ if ipc_name in self._mapping:
67
+ logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
67
68
  return
68
- logger.info(
69
- f"{type_str} not registered with worker {worker_id}, registering..."
70
- )
71
- socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
72
- self._mapping[worker_id] = socket
73
- self._mapping[worker_id].send_pyobj(recv_obj)
74
-
75
- def send_output(self, worker_id: str, output: Any):
76
- if worker_id not in self._mapping:
77
- logger.error(
78
- f"worker ID {worker_id} not registered. Check if the server Process is alive"
79
- )
69
+ logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
70
+ socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
71
+ self._mapping[ipc_name] = socket
72
+
73
+ def send_output(self, ipc_name: str, output: Any):
74
+ if ipc_name is None:
75
+ # Some unhandled cases
76
+ logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
80
77
  return
81
- self._mapping[worker_id].send_pyobj(output)
78
+
79
+ if ipc_name not in self._mapping:
80
+ self._register_ipc_mapping(ipc_name, is_tokenizer=False)
81
+ self._mapping[ipc_name].send_pyobj(output)
82
82
 
83
83
 
84
84
  def _handle_output_by_index(output, i):
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
190
190
  if output.output_token_ids_logprobs_idx
191
191
  else None
192
192
  ),
193
+ output_token_entropy_val=(
194
+ [output.output_token_entropy_val[i]]
195
+ if output.output_token_entropy_val
196
+ else None
197
+ ),
193
198
  output_hidden_states=(
194
199
  [output.output_hidden_states[i]]
195
200
  if output.output_hidden_states
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
197
202
  ),
198
203
  placeholder_tokens_idx=None,
199
204
  placeholder_tokens_val=None,
205
+ token_steps=([output.token_steps[i]] if output.token_steps else None),
200
206
  )
201
207
  elif isinstance(output, BatchEmbeddingOutput):
202
208
  new_output = BatchEmbeddingOutput(
@@ -246,6 +252,11 @@ def _handle_output_by_index(output, i):
246
252
  spec_verify_ct=(
247
253
  [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
248
254
  ),
255
+ spec_accepted_tokens=(
256
+ [output.spec_accepted_tokens[i]]
257
+ if len(output.spec_accepted_tokens) > i
258
+ else None
259
+ ),
249
260
  input_token_logprobs_val=(
250
261
  [output.input_token_logprobs_val[i]]
251
262
  if output.input_token_logprobs_val
@@ -306,6 +317,11 @@ def _handle_output_by_index(output, i):
306
317
  if output.output_token_ids_logprobs_idx
307
318
  else None
308
319
  ),
320
+ output_token_entropy_val=(
321
+ [output.output_token_entropy_val[i]]
322
+ if output.output_token_entropy_val
323
+ else None
324
+ ),
309
325
  output_hidden_states=(
310
326
  [output.output_hidden_states[i]]
311
327
  if output.output_hidden_states
@@ -313,6 +329,7 @@ def _handle_output_by_index(output, i):
313
329
  ),
314
330
  placeholder_tokens_idx=None,
315
331
  placeholder_tokens_val=None,
332
+ token_steps=([output.token_steps[i]] if output.token_steps else None),
316
333
  )
317
334
  elif isinstance(output, BatchMultimodalOutput):
318
335
  new_output = BatchMultimodalOutput(
@@ -345,20 +362,11 @@ def _handle_output_by_index(output, i):
345
362
  class MultiHttpWorkerDetokenizerMixin:
346
363
  """Mixin class for DetokenizerManager"""
347
364
 
348
- def get_worker_ids_from_req_rids(self, rids):
349
- if isinstance(rids, list):
350
- worker_ids = [int(rid.split("_")[0]) for rid in rids]
351
- elif isinstance(rids, str):
352
- worker_ids = [int(rids.split("_")[0])]
353
- else:
354
- worker_ids = []
355
- return worker_ids
356
-
357
- def maybe_clear_socket_mapping(self):
365
+ def maybe_clear_socket_mapping(self: DetokenizerManager):
358
366
  if hasattr(self, "socket_mapping"):
359
367
  self.socket_mapping.clear_all_sockets()
360
368
 
361
- def multi_http_worker_event_loop(self):
369
+ def multi_http_worker_event_loop(self: DetokenizerManager):
362
370
  """The event loop that handles requests, for multi multi-http-worker mode"""
363
371
  self.socket_mapping = SocketMapping()
364
372
  while True:
@@ -366,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin:
366
374
  output = self._request_dispatcher(recv_obj)
367
375
  if output is None:
368
376
  continue
369
- # Extract worker_id from rid
370
- if isinstance(recv_obj.rids, list):
371
- worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
372
- else:
373
- raise RuntimeError(
374
- f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
375
- )
377
+
378
+ assert isinstance(
379
+ recv_obj, BaseBatchReq
380
+ ), "for multi-http-worker, recv_obj must be BaseBatchReq"
376
381
 
377
382
  # Send data using the corresponding socket
378
- for i, worker_id in enumerate(worker_ids):
379
- if isinstance(recv_obj, MultiTokenizerRegisterReq):
380
- self.socket_mapping.register_ipc_mapping(
381
- recv_obj, worker_id, is_tokenizer=False
382
- )
383
- else:
384
- new_output = _handle_output_by_index(output, i)
385
- self.socket_mapping.send_output(worker_id, new_output)
383
+ for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
384
+ new_output = _handle_output_by_index(output, i)
385
+ self.socket_mapping.send_output(ipc_name, new_output)
386
386
 
387
387
 
388
388
  class MultiTokenizerRouter:
@@ -432,26 +432,17 @@ class MultiTokenizerRouter:
432
432
  await self._distribute_result_to_workers(recv_obj)
433
433
 
434
434
  async def _distribute_result_to_workers(self, recv_obj):
435
- """Distribute result to corresponding workers based on rid"""
436
- if isinstance(recv_obj, MultiTokenizerWrapper):
437
- worker_ids = [recv_obj.worker_id]
438
- recv_obj = recv_obj.obj
435
+ # Distribute result to each worker
436
+ if isinstance(recv_obj, BaseReq):
437
+ ipc_names = [recv_obj.http_worker_ipc]
438
+ elif isinstance(recv_obj, BaseBatchReq):
439
+ ipc_names = recv_obj.http_worker_ipcs
439
440
  else:
440
- worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
441
-
442
- if len(worker_ids) == 0:
443
- logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
444
- return
441
+ raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
445
442
 
446
- # Distribute result to each worker
447
- for i, worker_id in enumerate(worker_ids):
448
- if isinstance(recv_obj, MultiTokenizerRegisterReq):
449
- self.socket_mapping.register_ipc_mapping(
450
- recv_obj, worker_id, is_tokenizer=True
451
- )
452
- else:
453
- new_recv_obj = _handle_output_by_index(recv_obj, i)
454
- self.socket_mapping.send_output(worker_id, new_recv_obj)
443
+ for i, ipc_name in enumerate(ipc_names):
444
+ new_recv_obj = _handle_output_by_index(recv_obj, i)
445
+ self.socket_mapping.send_output(ipc_name, new_recv_obj)
455
446
 
456
447
 
457
448
  class TokenizerWorker(TokenizerManager):
@@ -483,21 +474,15 @@ class TokenizerWorker(TokenizerManager):
483
474
  self.register_multi_tokenizer_communicator = _Communicator(
484
475
  self.send_to_scheduler, 2
485
476
  )
486
- self._result_dispatcher._mapping.append(
487
- (
488
- MultiTokenizerRegisterReq,
489
- self.register_multi_tokenizer_communicator.handle_recv,
490
- )
491
- )
492
477
 
493
- async def register_to_main_tokenizer_manager(self):
494
- """Register this worker to the main TokenizerManager"""
495
- # create a handle loop to receive messages from the main TokenizerManager
496
- self.auto_create_handle_loop()
497
- req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
498
- req.ipc_name = self.tokenizer_ipc_name
499
- _Communicator.enable_multi_tokenizer = True
500
- await self.register_multi_tokenizer_communicator(req)
478
+ def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
479
+
480
+ if isinstance(req, BaseReq):
481
+ req.http_worker_ipc = self.tokenizer_ipc_name
482
+ elif isinstance(req, BaseBatchReq):
483
+ req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
484
+ else:
485
+ raise ValueError(f"Unknown req type: {type(req)}")
501
486
 
502
487
 
503
488
  async def print_exception_wrapper(func):