sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,578 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """package for sglang requests tracing"""
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ import os
20
+ import random
21
+ import threading
22
+ import time
23
+ import uuid
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
26
+
27
if TYPE_CHECKING:
    # Imported only for static type checking; not needed at runtime.
    from sglang.srt.managers.scheduler import Req

logger = logging.getLogger(__name__)
# Set to True below only if every OpenTelemetry import succeeds.
opentelemetry_imported = False
# Set to True by process_tracing_init() after the tracer provider is installed;
# the trace_* entry points in this module return early while it is False.
tracing_enabled = False

try:
    from opentelemetry import context, propagate, trace
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.resources import SERVICE_NAME, Resource
    from opentelemetry.sdk.trace import TracerProvider, id_generator
    from opentelemetry.sdk.trace.export import BatchSpanProcessor

    opentelemetry_imported = True
except ImportError:

    # Minimal stand-in so that SglangTraceCustomIdGenerator, which subclasses
    # id_generator.IdGenerator, can still be defined without OpenTelemetry.
    # NOTE(review): the other OpenTelemetry names (trace, context, propagate, ...)
    # stay undefined in this case. Annotations mentioning them are not evaluated
    # because of `from __future__ import annotations`, and runtime uses appear to
    # be guarded by tracing_enabled — confirm this for any new call site.
    class id_generator:
        class IdGenerator:
            pass

    logger.info("opentelemetry package is not installed, tracing disabled")
49
+
50
+
51
@dataclass
class SglangTraceThreadInfo:
    """Identity of one traced thread, as registered by trace_set_thread_info().

    Attributes:
        host_id: Node identifier used to tell apart PIDs on different machines.
        pid: Native thread id of the thread (``threading.get_native_id()``).
        thread_label: Human-readable label for the thread.
        tp_rank: Tensor-parallel rank, or None when not provided.
        dp_rank: Data-parallel rank, or None when not provided.
        tracer: OpenTelemetry tracer used to start this thread's spans.
    """

    host_id: str
    pid: int
    thread_label: str
    # trace_set_thread_info() defaults both ranks to None, so they are optional.
    tp_rank: Optional[int]
    dp_rank: Optional[int]
    tracer: trace.Tracer
59
+
60
+
61
@dataclass
class SglangTraceSliceContext:
    """One entry of a thread's ``cur_slice_stack`` (SglangTraceThreadContext)."""

    slice_name: str
    # OpenTelemetry span for this slice; may be None (the default).
    span: Optional[trace.span.Span] = None
    # When True, defers slice_name assignment until trace_slice_end()
    anonymous: bool = False
67
+
68
+
69
@dataclass
class SglangTraceThreadContext:
    """Tracing state of one thread for one request (stored in SglangTraceReqContext.threads_context)."""

    thread_info: SglangTraceThreadInfo
    # Slices currently open on this thread for the request.
    cur_slice_stack: List[SglangTraceSliceContext]
    # Per-thread span; __create_thread_context() starts it under the request's root context.
    thread_span: Optional[trace.span.Span] = None
    # Record the most recently completed span as the previous span for the next span to be created.
    last_span_context: Optional[trace.span.SpanContext] = None
76
+
77
+
78
@dataclass
class SglangTraceReqContext:
    """Per-process tracing state of one request, stored in ``reqs_context[str(rid)]``."""

    rid: str
    # Nanosecond timestamp: the request start for an original, creation time for a copy.
    start_time_ns: int
    # Keyed by native thread id (threading.get_native_id()).
    threads_context: Dict[int, SglangTraceThreadContext]
    bootstrap_room: Optional[int] = None

    # Indicates whether this instance is a replica from the main process.
    # When True, root_span is None and only root_span_context is preserved.
    is_copy: bool = False
    root_span: Optional[trace.span.Span] = None
    root_span_context: Optional[context.Context] = None
90
+
91
+
92
@dataclass
class SglangTracePropagateContext:
    """Trace context that is serialized to hand a request's trace to another process.

    Attributes:
        root_span_context: OpenTelemetry context carrying the request's root span.
        prev_span_context: Span context of the last span on the sending side,
            or None when there is none.
    """

    root_span_context: context.Context
    prev_span_context: Optional[trace.span.SpanContext]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict that ``instance_from_dict`` can read back.

        ``root_span`` holds the carrier filled by the global propagator.
        ``prev_span`` is either ``{"span_id": int, "trace_id": int}`` or the
        string ``"None"``.
        """
        carrier: Dict[str, str] = {}
        # Pass the context explicitly instead of context.attach()-ing it: an
        # attach() without a matching detach() would leave it installed as the
        # calling thread's current context after this method returns.
        propagate.inject(carrier, context=self.root_span_context)

        if self.prev_span_context is None:
            return {"root_span": carrier, "prev_span": "None"}

        return {
            "root_span": carrier,
            "prev_span": {
                "span_id": self.prev_span_context.span_id,
                "trace_id": self.prev_span_context.trace_id,
            },
        }

    @classmethod
    def instance_from_dict(cls, d):
        """Rebuild an instance from ``to_dict`` output.

        Returns None when ``d`` lacks the ``root_span`` or ``prev_span`` key.
        The previous span is recreated as a remote SpanContext.
        """
        if "root_span" not in d or "prev_span" not in d:
            return None

        root_span_context = propagate.extract(d["root_span"])

        prev_span = d["prev_span"]
        if prev_span == "None":
            prev_span_context = None
        else:
            prev_span_context = trace.span.SpanContext(
                trace_id=prev_span["trace_id"],
                span_id=prev_span["span_id"],
                is_remote=True,
            )

        return cls(root_span_context, prev_span_context)
131
+
132
+
133
class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
    """
    The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes,
    hence a custom IdGenerator is implemented.

    Each instance owns a private ``random.Random`` seeded by the interpreter
    from OS entropy, so separately created generators do not share state.
    """

    def __init__(self):
        super().__init__()
        # random.Random() without a seed draws its seed from os.urandom() where
        # available. Re-seeding it from time.time() would give processes that
        # initialise at the same instant identical ID sequences.
        self.local_random = random.Random()

    def _random_nonzero(self, bits: int) -> int:
        """Return a random ``bits``-wide integer that is never 0.

        OpenTelemetry treats an all-zero trace/span ID as invalid.
        """
        value = 0
        while value == 0:
            value = self.local_random.getrandbits(bits)
        return value

    def generate_trace_id(self) -> int:
        # NOTE(review): 64-bit trace IDs kept for compatibility, although
        # OpenTelemetry trace IDs are 128 bits wide.
        return self._random_nonzero(64)

    def generate_span_id(self) -> int:
        return self._random_nonzero(64)
149
+
150
+
151
# global variables
# Per-process registry of traced threads, keyed by native thread id.
threads_info: Dict[int, SglangTraceThreadInfo] = {}
# Per-process tracing state of in-flight requests, keyed by str(rid).
reqs_context: Dict[str, SglangTraceReqContext] = {}

# Wall-clock time in integer nanoseconds. time.time_ns() (Python 3.7+) avoids
# the precision loss of scaling the float returned by time.time().
__get_cur_time_ns = time.time_ns
156
+
157
+
158
def __get_host_id() -> str:
    """
    In distributed tracing systems, obtain a unique node identifier
    and inject it into all subsequently generated spans
    to prevent PID conflicts between threads on different nodes.

    Resolution order:
      1. The contents of ``/etc/machine-id``, when readable and non-empty.
      2. A hex digest derived from the MAC address from ``uuid.getnode()``.
      3. ``"unknown"``.
    """
    try:
        with open("/etc/machine-id", "r") as f:
            machine_id = f.read().strip()
    except (OSError, ValueError):
        # Missing, unreadable or undecodable file: fall back to the MAC address.
        machine_id = ""
    if machine_id:
        return machine_id

    mac = uuid.getnode()
    if mac != 0:
        # uuid.UUID(int=mac).hex would left-pad the 48-bit MAC with 20 zero
        # digits, so any prefix of it (thread names show host_id[:8]) would be
        # "00000000" on every host. Hashing spreads the entropy over all digits.
        # NOTE: if no hardware address can be found, uuid.getnode() returns a
        # random value, so the id is then only stable within one process.
        return uuid.uuid5(uuid.NAMESPACE_OID, f"{mac:012x}").hex

    return "unknown"
176
+
177
+
178
# Should be called by each tracked process.
def process_tracing_init(otlp_endpoint, server_name):
    """Install an OTLP-exporting tracer provider for this process and enable tracing.

    Args:
        otlp_endpoint: Endpoint handed to ``OTLPSpanExporter`` (the exporter
            is created with ``insecure=True``).
        server_name: Value stored under the ``SERVICE_NAME`` resource attribute.

    If OpenTelemetry is unavailable or provider/exporter setup raises, tracing
    stays disabled for this process; setup errors are logged, not re-raised.
    """
    global tracing_enabled
    global __get_cur_time_ns
    if not opentelemetry_imported:
        tracing_enabled = False
        return

    try:
        resource = Resource.create(attributes={SERVICE_NAME: server_name})
        tracer_provider = TracerProvider(
            resource=resource, id_generator=SglangTraceCustomIdGenerator()
        )
        processor = BatchSpanProcessor(
            OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
        )
        tracer_provider.add_span_processor(processor)
        trace.set_tracer_provider(tracer_provider)
    except Exception:
        # Tracing is best-effort: keep serving, but surface the traceback.
        logger.exception("Failed to initialize opentelemetry tracing")
        logger.warning("Please set a correct otlp endpoint (got %r)", otlp_endpoint)
        tracing_enabled = False
        return

    # Integer-nanosecond span timestamps (time.time_ns exists since Python 3.7).
    __get_cur_time_ns = time.time_ns

    tracing_enabled = True
211
+
212
+
213
# Should be called by each tracked thread.
def trace_set_thread_info(
    thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None
):
    """Register the calling thread in ``threads_info``.

    Does nothing when tracing is disabled or when the calling thread is
    already registered (the first registration wins).
    """
    if tracing_enabled:
        native_tid = threading.get_native_id()
        if native_tid not in threads_info:
            threads_info[native_tid] = SglangTraceThreadInfo(
                host_id=__get_host_id(),
                pid=native_tid,
                thread_label=thread_label,
                tp_rank=tp_rank,
                dp_rank=dp_rank,
                tracer=trace.get_tracer("sglang server"),
            )
232
+
233
+
234
def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
    """Create a request's thread context and start its thread span.

    Args:
        pid: Native thread id. If it is not registered yet, the calling thread
            is registered as "unknown" via trace_set_thread_info(), which keys
            on threading.get_native_id(), so ``pid`` is expected to be the
            caller's own id.
        req_span_context: Parent context for the new thread span.
        ts: Span start time in nanoseconds; None means now.

    Returns:
        A SglangTraceThreadContext with an empty slice stack and a started
        thread span.
    """
    if pid not in threads_info:
        trace_set_thread_info("unknown")

    thread_info = threads_info[pid]
    thread_context = SglangTraceThreadContext(
        thread_info=thread_info,
        cur_slice_stack=[],
    )

    # e.g. "scheduler [TP 0] (host:1a2b3c4d | pid:1234)". Parts are space-joined
    # so the layout is the same with or without a TP rank.
    name_parts = [f"{thread_info.thread_label}"]
    if thread_info.tp_rank is not None:
        name_parts.append(f"[TP {thread_info.tp_rank}]")
    name_parts.append(f"(host:{thread_info.host_id[:8]} | pid:{pid})")
    thread_name = " ".join(name_parts)

    if ts is None:
        ts = __get_cur_time_ns()
    thread_context.thread_span = thread_info.tracer.start_span(
        name=thread_name,
        start_time=ts,
        context=req_span_context,
    )

    attributes = {
        "host_id": thread_info.host_id,
        "pid": thread_info.pid,
        "thread_label": thread_info.thread_label,
    }
    # Ranks are optional; record them only when known.
    if thread_info.tp_rank is not None:
        attributes["tp_rank"] = thread_info.tp_rank
    if thread_info.dp_rank is not None:
        attributes["dp_rank"] = thread_info.dp_rank
    thread_context.thread_span.set_attributes(attributes)

    return thread_context
267
+
268
+
269
def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
    """Serialize the trace context of request ``rid`` for another process.

    Returns None when tracing is disabled or the request has no root span
    context in this process. Otherwise returns the ``to_dict()`` form of a
    SglangTracePropagateContext. Its previous span is the calling thread's
    outermost open slice, or else that thread's last completed span (if any).
    """
    if not tracing_enabled:
        return None

    rid = str(rid)
    req_context = reqs_context.get(rid)
    if req_context is None or not req_context.root_span_context:
        return None

    prev_span_context = None
    # The calling thread may not have joined this request; still propagate the
    # root context, just without a previous span, instead of raising KeyError.
    thread_context = req_context.threads_context.get(threading.get_native_id())
    if thread_context is not None:
        outer_slice = (
            thread_context.cur_slice_stack[0]
            if thread_context.cur_slice_stack
            else None
        )
        # A slice's span is Optional; fall back rather than dereference None.
        if outer_slice is not None and outer_slice.span is not None:
            prev_span_context = outer_slice.span.get_span_context()
        else:
            prev_span_context = thread_context.last_span_context

    trace_context = SglangTracePropagateContext(
        req_context.root_span_context, prev_span_context
    )
    return trace_context.to_dict()
290
+
291
+
292
def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]):
    """Adopt a trace context received from another process for request ``rid``.

    ``trace_context`` is a dict in the format of
    SglangTracePropagateContext.to_dict(). On first sight of ``rid`` a copy
    request context (``is_copy=True``) is created. Unless the calling thread
    already has a context for the request, a thread context is then created and
    its ``last_span_context`` seeded with the sender's previous span. This is a
    no-op when tracing is disabled or the payload is empty or invalid.
    """
    if not tracing_enabled or not trace_context:
        return

    propagated = SglangTracePropagateContext.instance_from_dict(trace_context)
    if not propagated:
        return

    key = str(rid)
    req_context = reqs_context.get(key)
    if req_context is None:
        # Create a copy of the request context
        req_context = SglangTraceReqContext(
            rid=key,
            start_time_ns=__get_cur_time_ns(),
            threads_context={},
            root_span_context=propagated.root_span_context,
            is_copy=True,
        )
        reqs_context[key] = req_context

    native_tid = threading.get_native_id()
    if native_tid in req_context.threads_context:
        return

    # Create new thread context.
    new_thread_context = __create_thread_context(
        native_tid,
        propagated.root_span_context,
        req_context.start_time_ns,
    )
    new_thread_context.last_span_context = propagated.prev_span_context
    req_context.threads_context[native_tid] = new_thread_context
328
+
329
+
330
def trace_req_start(
    rid: str,
    bootstrap_room: Optional[int] = None,
    ts: Optional[int] = None,
):
    """Open the root span for request ``rid`` and this thread's span under it.

    Only threads registered in ``threads_info`` can start a request; on any
    other thread this is a no-op. Any existing context for ``rid`` in this
    process is replaced.

    Args:
        rid: Request id; converted with ``str()``.
        bootstrap_room: Optional bootstrap room id, recorded as a root-span
            attribute (the string ``"None"`` when absent).
        ts: Start timestamp in ns; when falsy, the current time is used.
    """
    if not tracing_enabled:
        return

    rid = str(rid)

    ts = ts or __get_cur_time_ns()

    pid = threading.get_native_id()
    if pid not in threads_info:
        return

    # Create the request context; the root span is attached below.
    req_context = SglangTraceReqContext(
        rid=rid,
        start_time_ns=ts,
        threads_context={},
        bootstrap_room=bootstrap_room,
        is_copy=False,
    )
    reqs_context[rid] = req_context

    # Drop the worker_id added by MultiTokenizer: keep only the text after
    # the last "_" for the span name.
    orig_rid = rid.split("_")[-1]
    root_span = threads_info[pid].tracer.start_span(
        name=f"Req {orig_rid[:8]}",
        start_time=ts,
    )

    root_span.set_attributes(
        {
            "rid": rid,
            # Compare against None explicitly so a room id of 0 is recorded
            # as 0 rather than "None".
            "bootstrap_room": bootstrap_room if bootstrap_room is not None else "None",
        }
    )

    req_context.root_span = root_span
    req_context.root_span_context = trace.set_span_in_context(root_span)

    # Create the thread context and thread span under the root span.
    req_context.threads_context[pid] = __create_thread_context(
        pid,
        req_context.root_span_context,
        ts,
    )
379
+
380
+
381
def trace_req_finish(
    rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None
):
    """End every span of request ``rid`` in this process and drop its context.

    Args:
        rid: Request id; converted with ``str()`` before lookup. Unknown ids
            are ignored.
        ts: End timestamp in ns; when falsy, the current time is used.
        attrs: Optional attributes set on the root span before it ends.
    """
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    req_context = reqs_context[rid]
    ts = ts or __get_cur_time_ns()

    # End all unclosed thread spans.
    for thread_context in req_context.threads_context.values():
        thread_context.thread_span.end(end_time=ts)

    # Copied contexts (created by trace_set_proc_propagate_context) never get
    # a root span assigned; presumably it stays at its default of None. TODO:
    # confirm against SglangTraceReqContext. Guard so finishing such a copy
    # neither raises AttributeError nor skips removing the context below.
    root_span = req_context.root_span
    if root_span is not None:
        if attrs:
            root_span.set_attributes(attrs)
        root_span.end(end_time=ts)

    del reqs_context[rid]
404
+
405
+
406
def trace_slice_start(
    name: str,
    rid: str,
    ts: Optional[int] = None,
    anonymous: bool = False,
):
    """Open a new slice span for request ``rid`` on the calling thread.

    A top-level slice (no slice currently open) is parented under the thread
    span and linked to the last top-level slice, when one is recorded. A
    nested slice is parented under the innermost open slice, with no link.

    Args:
        name: Slice name (the empty string for anonymous slices).
        rid: Request id; converted with ``str()`` before lookup.
        ts: Start timestamp in ns; when falsy, the current time is used.
        anonymous: The slice is renamed when it is closed by
            ``trace_slice_end``.
    """
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    threads_context = reqs_context[rid].threads_context
    pid = threading.get_native_id()
    if pid not in threads_context:
        return
    thread_context = threads_context[pid]

    ts = ts or __get_cur_time_ns()

    new_slice = SglangTraceSliceContext(slice_name=name, anonymous=anonymous)

    open_slices = thread_context.cur_slice_stack
    if open_slices:
        # Nested slice: hang it under the innermost open slice.
        link_target = None
        parent_span = open_slices[-1].span
    else:
        # Top-level slice: hang it under the thread span and link it to the
        # previous top-level slice, if any.
        link_target = thread_context.last_span_context
        parent_span = thread_context.thread_span

    span = thread_context.thread_info.tracer.start_span(
        name=new_slice.slice_name,
        start_time=ts,
        context=trace.set_span_in_context(parent_span),
    )
    if link_target:
        span.add_link(link_target)

    new_slice.span = span
    open_slices.append(new_slice)
455
+
456
+
457
def trace_slice_end(
    name: str,
    rid: str,
    ts: Optional[int] = None,
    attrs: Optional[Dict[str, Any]] = None,
    auto_next_anon: bool = False,
    thread_finish_flag: bool = False,
):
    """Close the innermost open slice of request ``rid`` on the calling thread.

    Args:
        name: Expected slice name. An anonymous slice is renamed to ``name``.
            For a named slice whose name differs, the span is marked with
            ERROR status and a warning is logged.
        rid: Request id; converted with ``str()`` before lookup.
        ts: End timestamp in ns; when falsy, the current time is used.
        attrs: Optional attributes set on the slice span before it ends.
        auto_next_anon: Immediately open an anonymous slice at ``ts`` so a
            later ``trace_slice_end`` call can name it.
        thread_finish_flag: This is the thread's last slice for the request.
            End the thread span and drop the thread context; a copied request
            context left with no threads is dropped too. Takes precedence
            over ``auto_next_anon``.
    """
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
    if pid not in reqs_context[rid].threads_context:
        return

    thread_context = reqs_context[rid].threads_context[pid]

    if not thread_context.cur_slice_stack:
        logger.warning(f"No matching SLICE_START event found for slice '{name}'.")
        return

    ts = ts or __get_cur_time_ns()
    slice_info = thread_context.cur_slice_stack[-1]
    span = slice_info.span

    if slice_info.anonymous:
        # Anonymous slices are opened with an empty name (see auto_next_anon
        # below) and only receive their real name when they are closed.
        span.update_name(name)
    elif slice_info.slice_name != name:
        span.set_status(trace.Status(trace.StatusCode.ERROR))
        logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}")

    if attrs:
        span.set_attributes(attrs)

    span.end(end_time=ts)

    thread_context.cur_slice_stack.pop()
    if not thread_context.cur_slice_stack:
        # Remember the last top-level slice so the next top-level slice, or a
        # propagated context, can link to it.
        thread_context.last_span_context = span.get_span_context()

    # If this is the last slice in the thread, release the thread context and
    # check whether to release the request context.
    if thread_finish_flag:
        thread_context.thread_span.end(end_time=ts)
        req_context = reqs_context[rid]
        del req_context.threads_context[pid]
        if req_context.is_copy and not req_context.threads_context:
            del reqs_context[rid]
        return

    if auto_next_anon:
        trace_slice_start("", rid, ts, True)


# alias
trace_slice = trace_slice_end
518
+
519
+
520
def trace_event(name: str, rid: str, ts: Optional[int] = None):
    """Record event ``name`` on the innermost open slice of ``rid`` on this thread.

    No-op when tracing is disabled or the request/thread is not being traced.
    Logs a warning when the thread has no open slice for the request.

    Args:
        name: Event name.
        rid: Request id; converted with ``str()`` before lookup.
        ts: Event timestamp in ns; when falsy, the current time is used.
    """
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    threads_context = reqs_context[rid].threads_context
    pid = threading.get_native_id()
    if pid not in threads_context:
        return

    open_slices = threads_context[pid].cur_slice_stack
    if not open_slices:
        logger.warning("No slice is currently being traced.")
        return

    event_ts = ts or __get_cur_time_ns()
    open_slices[-1].span.add_event(name=name, timestamp=event_ts)
543
+
544
+
545
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
    """Set ``attrs`` on the innermost open slice of ``rid`` on this thread.

    No-op when tracing is disabled or the request/thread is not being traced.
    Logs a warning when the thread has no open slice for the request.

    Args:
        rid: Request id; converted with ``str()`` before lookup.
        attrs: Attributes passed to the slice span's ``set_attributes``.
    """
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    threads_context = reqs_context[rid].threads_context
    pid = threading.get_native_id()
    if pid not in threads_context:
        return

    open_slices = threads_context[pid].cur_slice_stack
    if not open_slices:
        logger.warning("No slice is currently being traced.")
        return

    open_slices[-1].span.set_attributes(attrs)
566
+
567
+
568
def trace_slice_batch(
    name: str,
    reqs: List[Req],
):
    """Close slice ``name`` for every request in ``reqs`` on the calling thread.

    Unfinished requests immediately open an anonymous follow-up slice
    (``auto_next_anon``). Finished requests also end this thread's span for
    the request (``thread_finish_flag``).

    Args:
        name: Name of the slice being closed.
        reqs: Requests in the batch; each must expose ``rid`` and
            ``finished()``.
    """
    # trace_slice is a no-op when tracing is off; skip the per-request loop
    # (and the finished() calls) entirely in that case.
    if not tracing_enabled:
        return

    for req in reqs:
        # Evaluate finished() once so the two flags are always complementary.
        finished = req.finished()
        trace_slice(
            name,
            req.rid,
            auto_next_anon=not finished,
            thread_finish_flag=finished,
        )
@@ -30,8 +30,9 @@ from sglang.srt.model_executor.forward_batch_info import (
30
30
  )
31
31
  from sglang.srt.operations import execute_operations, execute_overlapped_operations
32
32
  from sglang.srt.operations_strategy import OperationsStrategy
33
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
34
- from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
33
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
34
+ from sglang.srt.speculative.spec_info import SpecInput
35
+ from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
35
36
 
36
37
  if TYPE_CHECKING:
37
38
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
48
49
 
49
50
  def get_token_num_per_seq(
50
51
  forward_mode: ForwardMode,
51
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
52
+ spec_info: Optional[SpecInput] = None,
52
53
  ):
53
54
  if forward_mode.is_target_verify():
54
55
  return spec_info.draft_token_num
@@ -273,7 +274,7 @@ def compute_split_token_index(
273
274
  def compute_split_indices_for_cuda_graph_replay(
274
275
  forward_mode: ForwardMode,
275
276
  cuda_graph_num_tokens: int,
276
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
277
+ spec_info: Optional[SpecInput],
277
278
  ):
278
279
  forward_mode_for_tbo_split = (
279
280
  forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
333
334
  forward_mode: ForwardMode,
334
335
  bs: int,
335
336
  num_token_non_padded: int,
336
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
337
+ spec_info: Optional[SpecInput],
337
338
  ):
338
339
  token_num_per_seq = get_token_num_per_seq(
339
340
  forward_mode=forward_mode, spec_info=spec_info
@@ -704,6 +705,8 @@ class TboForwardBatchPreparer:
704
705
  extend_num_tokens=extend_num_tokens,
705
706
  attn_backend=output_attn_backend,
706
707
  num_token_non_padded=out_num_token_non_padded,
708
+ # TODO: handle it when we need TBO + DeepSeek V3.2
709
+ num_token_non_padded_cpu=None,
707
710
  tbo_split_seq_index=None,
708
711
  tbo_parent_token_range=(start_token_index, end_token_index),
709
712
  tbo_children=None,
@@ -0,0 +1,2 @@
1
+ # Temporarily do this to avoid changing all imports in the repo
2
+ from .common import *