sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,199 @@
1
+ """Merge Chrome trace files from multiple ranks (TP, DP, PP, EP) into a single trace."""
2
+
3
+ import glob
4
+ import gzip
5
+ import json
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class ProfileMerger:
    """Merge profile traces from all parallelism types: TP, DP, PP, EP.

    Per-rank Chrome trace files (``<profile_id>*.trace.json.gz``) found in
    ``output_dir`` are combined into a single gzip-compressed JSON trace
    written to ``merged-<profile_id>.trace.json.gz`` in the same directory.
    Every event's ``pid`` is prefixed with a rank label (e.g. ``[TP00-DP01]``)
    so events from different ranks stay distinguishable after the merge.
    """

    def __init__(self, output_dir: str, profile_id: str):
        """Store the search location and derive the merged output path.

        Args:
            output_dir: Directory holding the per-rank trace files; the merged
                trace is written here as well.
            profile_id: Filename prefix shared by all traces of one profiling
                run.
        """
        self.output_dir = output_dir
        self.profile_id = profile_id
        self.merged_trace_path = os.path.join(
            output_dir, f"merged-{profile_id}.trace.json.gz"
        )

        # Rank types in priority order (used for sorting and labeling)
        self.rank_types = ["tp", "dp", "pp", "ep"]

        # Sort index multipliers: DP (highest) > EP > PP > TP (lowest)
        # These ensure proper visual ordering in trace viewer
        self.sort_index_multipliers = {
            "dp_rank": 100_000_000,
            "ep_rank": 1_000_000,
            "pp_rank": 10_000,
            "tp_rank": 100,
        }

        # PID threshold for sort_index updates (only update for system PIDs < 1000)
        self.pid_sort_index_threshold = 1000

    def merge_chrome_traces(self) -> str:
        """Merge Chrome traces from all ranks into a single trace.

        Returns:
            Path to merged trace file.

        Raises:
            ValueError: If no trace files found.
        """
        trace_files = self._discover_trace_files()
        if not trace_files:
            raise ValueError(f"No trace files found for profile_id: {self.profile_id}")

        logger.info(f"Found {len(trace_files)} trace files to merge")

        merged_trace: Dict[str, Any] = {"traceEvents": []}
        all_device_properties: List[Any] = []

        for trace_file in sorted(trace_files, key=self._get_rank_sort_key):
            rank_info = self._extract_rank_info(trace_file)
            logger.info(f"Processing {trace_file} with rank info: {rank_info}")

            output = self._handle_file(trace_file, rank_info)

            merged_trace["traceEvents"].extend(output["traceEvents"])

            # Device properties are accumulated across all ranks; a missing or
            # null entry is simply skipped instead of crashing the merge.
            device_properties = output.pop("deviceProperties", None)
            if device_properties:
                all_device_properties.extend(device_properties)

            # Any other top-level metadata is taken from the first file that
            # provides it ("traceEvents" is already present, so it is never
            # overwritten here).
            for key, value in output.items():
                merged_trace.setdefault(key, value)

        if all_device_properties:
            merged_trace["deviceProperties"] = all_device_properties

        with gzip.open(self.merged_trace_path, "wb") as f:
            f.write(json.dumps(merged_trace).encode("utf-8"))

        logger.info(f"Merged profile saved to: {self.merged_trace_path}")
        logger.info(f"Total events merged: {len(merged_trace['traceEvents'])}")

        return self.merged_trace_path

    def _discover_trace_files(self) -> List[str]:
        """Discover trace files matching profile_id (supports TP/DP/PP/EP formats).

        Returns:
            Sorted, de-duplicated paths of per-rank trace files. The merged
            output file and files without a ``TP-`` marker in their name are
            excluded.
        """
        # Escape glob metacharacters so that a directory or profile id
        # containing e.g. "[" is matched literally.
        search_pattern = os.path.join(
            glob.escape(self.output_dir),
            f"{glob.escape(self.profile_id)}*.trace.json.gz",
        )
        merged_suffix = f"merged-{self.profile_id}.trace.json.gz"

        matches = {
            path
            for path in glob.glob(search_pattern)
            if not path.endswith(merged_suffix)
            and not path.endswith("-memory.pickle")
            # Check the file name only: a "TP-" in a parent directory name
            # must not make unrelated files look like rank traces.
            and "TP-" in os.path.basename(path)
        }
        # Sorted for a deterministic merge order and summary listing.
        return sorted(matches)

    def _extract_rank_info(self, filename: str) -> Dict[str, int]:
        """Extract rank info (TP/DP/PP/EP) from filename.

        Returns:
            Mapping such as ``{"tp_rank": 0, "dp_rank": 1}`` containing only
            the rank types whose ``<TYPE>-<number>`` marker appears in the
            file's base name.
        """
        basename = os.path.basename(filename)
        rank_info: Dict[str, int] = {}

        for rank_type in self.rank_types:
            match = re.search(rf"{rank_type.upper()}-(\d+)", basename)
            if match:
                rank_info[f"{rank_type}_rank"] = int(match.group(1))

        return rank_info

    def _create_rank_label(self, rank_info: Dict[str, int]) -> str:
        """Build a label like ``[TP00-DP01]``, or ``[Unknown]`` if no rank is known."""
        parts = [
            f"{rank_type.upper()}{rank_info[f'{rank_type}_rank']:02d}"
            for rank_type in self.rank_types
            if f"{rank_type}_rank" in rank_info
        ]
        return f"[{'-'.join(parts)}]" if parts else "[Unknown]"

    def _handle_file(self, path: str, rank_info: Dict[str, int]) -> Dict[str, Any]:
        """Load one gzip-compressed trace file and relabel its events.

        Best effort: if the file cannot be read or processed, the error is
        logged and an empty trace is returned so the other ranks still merge.

        Returns:
            The file's top-level keys with ``traceEvents`` replaced by the
            processed events, or ``{"traceEvents": []}`` on failure.
        """
        logger.info(f"Processing file: {path}")

        try:
            with gzip.open(path, "rt", encoding="utf-8") as f:
                trace = json.load(f)

            output = {
                key: value for key, value in trace.items() if key != "traceEvents"
            }
            output["traceEvents"] = self._process_events(
                trace.get("traceEvents", []), rank_info
            )
            return output

        except Exception as e:
            logger.error(f"Failed to process trace file {path}: {e}")
            return {"traceEvents": []}

    def _process_events(
        self, events: List[Dict], rank_info: Dict[str, int]
    ) -> List[Dict]:
        """Process events: update sort_index and add rank labels to PIDs.

        The events are modified in place and the same list is returned.
        Events lacking ``args`` or ``pid`` are tolerated so that one unusual
        event cannot cause the whole file to be discarded by the caller.
        """
        rank_label = self._create_rank_label(rank_info)

        for event in events:
            if event.get("name") == "process_sort_index":
                pid = self._maybe_cast_int(event.get("pid"))
                if pid is not None and pid < self.pid_sort_index_threshold:
                    args = event.get("args")
                    if not isinstance(args, dict):
                        args = event["args"] = {}
                    args["sort_index"] = self._calculate_sort_index(rank_info, pid)

            # Only relabel events that actually carry a pid.
            if "pid" in event:
                event["pid"] = f"{rank_label} {event['pid']}"

        return events

    def _calculate_sort_index(self, rank_info: Dict[str, int], pid: int) -> int:
        """Return ``pid`` offset by each known rank times its multiplier."""
        return pid + sum(
            rank_info.get(rank_key, 0) * multiplier
            for rank_key, multiplier in self.sort_index_multipliers.items()
        )

    def _get_rank_sort_key(self, path: str) -> Tuple[int, int, int, int]:
        """Return the ``(dp, ep, pp, tp)`` ranks of a file; missing ranks count as 0."""
        rank_info = self._extract_rank_info(path)
        return tuple(
            rank_info.get(f"{rank_type}_rank", 0)
            for rank_type in ["dp", "ep", "pp", "tp"]
        )

    def _maybe_cast_int(self, x) -> Optional[int]:
        """Return ``int(x)``, or ``None`` when ``x`` cannot be converted."""
        try:
            return int(x)
        except (ValueError, TypeError):
            return None

    def get_merge_summary(self) -> Dict[str, Any]:
        """Summarize the merged trace file.

        Returns:
            A dict with the merged file path, event count, source file count
            and names, profile id and device-property count; or a dict with a
            single ``"error"`` key when the merged file is missing or
            unreadable.
        """
        if not os.path.exists(self.merged_trace_path):
            return {"error": "Merged trace file not found"}

        try:
            # Explicit encoding, matching how the per-rank files are read.
            with gzip.open(self.merged_trace_path, "rt", encoding="utf-8") as f:
                merged_data = json.load(f)

            trace_files = self._discover_trace_files()

            return {
                "merged_file": self.merged_trace_path,
                "total_events": len(merged_data.get("traceEvents", [])),
                "total_files": len(trace_files),
                "source_files": [os.path.basename(f) for f in trace_files],
                "profile_id": self.profile_id,
                "device_properties_count": len(merged_data.get("deviceProperties", [])),
            }
        except Exception as e:
            return {"error": f"Failed to read merged trace: {str(e)}"}
@@ -66,7 +66,7 @@ class MockModelRunner:
66
66
  enable_memory_saver=False,
67
67
  )
68
68
  # Required by torch native backend
69
- self.server_args = ServerArgs(model_path="fake_model_path")
69
+ self.server_args = ServerArgs(model_path="dummy")
70
70
 
71
71
 
72
72
  @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
@@ -4,7 +4,6 @@ import torch
4
4
 
5
5
  from sglang.srt.configs.model_config import AttentionArch
6
6
  from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
7
- from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
8
7
  from sglang.srt.layers.radix_attention import RadixAttention
9
8
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
10
9
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -2,8 +2,6 @@ import unittest
2
2
 
3
3
  import torch
4
4
 
5
- from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
6
- from sglang.srt.layers.radix_attention import RadixAttention
7
5
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
8
6
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
9
7
  from sglang.test.test_utils import CustomTestCase
@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
16
16
  TRTLLMMLABackend,
17
17
  TRTLLMMLADecodeMetadata,
18
18
  )
19
- from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
19
+ from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
20
20
  from sglang.srt.layers.radix_attention import RadixAttention
21
21
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
22
22
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+ from sglang.srt.server_args import (
24
+ ServerArgs,
25
+ get_global_server_args,
26
+ set_global_server_args_for_scheduler,
27
+ )
23
28
  from sglang.srt.utils import is_flashinfer_available
24
29
  from sglang.test.test_utils import CustomTestCase
25
30
 
@@ -104,15 +109,15 @@ TEST_CASES = {
104
109
  "page_size": 32,
105
110
  "description": "Single FP16 vs reference",
106
111
  },
107
- {
108
- "name": "single_fp8",
109
- "batch_size": 1,
110
- "max_seq_len": 64,
111
- "page_size": 64,
112
- "tolerance": 1e-1,
113
- "kv_cache_dtype": torch.float8_e4m3fn,
114
- "description": "Single FP8 vs reference",
115
- },
112
+ # {
113
+ # "name": "single_fp8",
114
+ # "batch_size": 1,
115
+ # "max_seq_len": 64,
116
+ # "page_size": 64,
117
+ # "tolerance": 1e-1,
118
+ # "kv_cache_dtype": torch.float8_e4m3fn,
119
+ # "description": "Single FP8 vs reference",
120
+ # },
116
121
  {
117
122
  "name": "batch_fp16",
118
123
  "batch_size": 32,
@@ -120,15 +125,15 @@ TEST_CASES = {
120
125
  "page_size": 32,
121
126
  "description": "Batch FP16 vs reference",
122
127
  },
123
- {
124
- "name": "batch_fp8",
125
- "batch_size": 32,
126
- "max_seq_len": 64,
127
- "page_size": 64,
128
- "tolerance": 1e-1,
129
- "kv_cache_dtype": torch.float8_e4m3fn,
130
- "description": "Batch FP8 vs reference",
131
- },
128
+ # {
129
+ # "name": "batch_fp8",
130
+ # "batch_size": 32,
131
+ # "max_seq_len": 64,
132
+ # "page_size": 64,
133
+ # "tolerance": 1e-1,
134
+ # "kv_cache_dtype": torch.float8_e4m3fn,
135
+ # "description": "Batch FP8 vs reference",
136
+ # },
132
137
  ],
133
138
  "page_size_consistency": [
134
139
  # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
@@ -213,13 +218,7 @@ class MockModelRunner:
213
218
  self.page_size = config["page_size"]
214
219
 
215
220
  # Server args stub - needed by attention backends
216
- self.server_args = type(
217
- "ServerArgs",
218
- (),
219
- {
220
- "enable_dp_attention": False, # Default value for testing
221
- },
222
- )
221
+ self.server_args = get_global_server_args()
223
222
 
224
223
  # Model-config stub with MLA attributes
225
224
  self.model_config = type(
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
320
319
  class TestTRTLLMMLA(CustomTestCase):
321
320
  """Test suite for TRTLLM MLA backend with centralized configuration."""
322
321
 
322
+ @classmethod
323
+ def setUpClass(cls):
324
+ """Set up global server args for testing."""
325
+ server_args = ServerArgs(model_path="dummy")
326
+ server_args.enable_dp_attention = False
327
+ set_global_server_args_for_scheduler(server_args)
328
+
329
+ @classmethod
330
+ def tearDownClass(cls):
331
+ pass
332
+
323
333
  def _merge_config(self, test_case):
324
334
  """Merge test case with default configuration."""
325
335
  config = DEFAULT_CONFIG.copy()
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
841
851
  backend.init_forward_metadata(fb)
842
852
 
843
853
  # Verify metadata exists
844
- self.assertIsNotNone(backend.forward_metadata)
845
- self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
854
+ self.assertIsNotNone(backend.forward_decode_metadata)
855
+ self.assertIsInstance(
856
+ backend.forward_decode_metadata, TRTLLMMLADecodeMetadata
857
+ )
846
858
 
847
859
  # Test metadata structure
848
- metadata = backend.forward_metadata
849
- self.assertIsNotNone(
850
- metadata.workspace, "Workspace should be allocated"
851
- )
860
+ metadata = backend.forward_decode_metadata
852
861
  self.assertIsNotNone(
853
862
  metadata.block_kv_indices, "Block KV indices should be created"
854
863
  )
855
864
 
856
- # Test workspace properties
857
- self.assertEqual(metadata.workspace.device.type, "cuda")
858
- self.assertEqual(metadata.workspace.dtype, torch.uint8)
859
- self.assertGreater(
860
- metadata.workspace.numel(), 0, "Workspace should have non-zero size"
861
- )
862
-
863
865
  # Test block KV indices properties
864
866
  self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
865
867
  self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
915
917
 
916
918
  # Should satisfy TRT-LLM and Triton constraints
917
919
  trtllm_constraint = 128 // scenario["page_size"]
918
- constraint_lcm = math.lcm(
919
- trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
920
+ triton_constraint = get_num_page_per_block_flashmla(
921
+ scenario["page_size"]
920
922
  )
923
+ constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
921
924
  self.assertEqual(
922
925
  calculated_blocks % constraint_lcm,
923
926
  0,
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
965
968
 
966
969
  # Initialize metadata
967
970
  backend.init_forward_metadata(fb)
968
- metadata = backend.forward_metadata
971
+ metadata = backend.forward_decode_metadata
969
972
 
970
973
  # Verify KV indices structure
971
974
  block_kv_indices = metadata.block_kv_indices
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
1016
1019
 
1017
1020
  # Verify CUDA graph buffers are allocated
1018
1021
  self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
1019
- self.assertIsNotNone(backend.decode_cuda_graph_workspace)
1020
1022
 
1021
1023
  # Test capture metadata
1022
1024
  seq_lens = torch.full(
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
1038
1040
  self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
1039
1041
  capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
1040
1042
 
1041
- self.assertIsNotNone(capture_metadata.workspace)
1042
1043
  self.assertIsNotNone(capture_metadata.block_kv_indices)
1043
1044
 
1044
1045
  # Test replay with different sequence lengths
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
1061
1062
  )
1062
1063
 
1063
1064
  # Verify replay updated the metadata
1064
- replay_metadata = backend.forward_metadata
1065
+ replay_metadata = backend.forward_decode_metadata
1065
1066
  self.assertIsNotNone(replay_metadata)
1066
- self.assertEqual(
1067
- replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
1068
- )
1069
1067
 
1070
1068
  def test_metadata_consistency_across_calls(self):
1071
1069
  """Test metadata consistency across multiple forward calls."""
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
1083
1081
  config["batch_size"], seq_lens_1, backend, model_runner, config
1084
1082
  )
1085
1083
  backend.init_forward_metadata(fb_1)
1086
- metadata_1 = backend.forward_metadata
1084
+ metadata_1 = backend.forward_decode_metadata
1087
1085
 
1088
1086
  # Second call with same sequence lengths
1089
1087
  seq_lens_2 = torch.tensor([32, 48], device=config["device"])
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
1091
1089
  config["batch_size"], seq_lens_2, backend, model_runner, config
1092
1090
  )
1093
1091
  backend.init_forward_metadata(fb_2)
1094
- metadata_2 = backend.forward_metadata
1092
+ metadata_2 = backend.forward_decode_metadata
1095
1093
 
1096
1094
  # Metadata structure should be consistent
1097
- self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
1098
1095
  self.assertEqual(
1099
1096
  metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
1100
1097
  )
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
1105
1102
  config["batch_size"], seq_lens_3, backend, model_runner, config
1106
1103
  )
1107
1104
  backend.init_forward_metadata(fb_3)
1108
- metadata_3 = backend.forward_metadata
1105
+ metadata_3 = backend.forward_decode_metadata
1109
1106
 
1110
1107
  # Should still have valid structure
1111
- self.assertIsNotNone(metadata_3.workspace)
1112
1108
  self.assertIsNotNone(metadata_3.block_kv_indices)
1113
1109
  self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
1114
1110
 
@@ -1263,6 +1259,178 @@ class TestTRTLLMMLA(CustomTestCase):
1263
1259
  f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
1264
1260
  )
1265
1261
 
1262
+ def test_draft_extend_padding_unpadding_kernels(self):
1263
+ """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel."""
1264
+
1265
+ # Import the kernels
1266
+ from sglang.srt.layers.attention.trtllm_mla_backend import (
1267
+ pad_draft_extend_query_kernel,
1268
+ unpad_draft_extend_output_kernel,
1269
+ )
1270
+
1271
+ def _create_test_data(
1272
+ self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32
1273
+ ):
1274
+ """Create test data for kernel testing."""
1275
+ device = torch.device("cuda")
1276
+
1277
+ # Create sequence lengths (varying lengths for each batch)
1278
+ seq_lens = torch.randint(
1279
+ 1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32
1280
+ )
1281
+
1282
+ # Create cumulative sequence lengths
1283
+ cum_seq_lens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
1284
+ cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)
1285
+
1286
+ # Create input query tensor (flattened format)
1287
+ total_tokens = cum_seq_lens[-1].item()
1288
+ q_input = torch.randn(
1289
+ total_tokens, num_heads, head_dim, device=device, dtype=dtype
1290
+ )
1291
+
1292
+ # Create padded query tensor (batch format)
1293
+ padded_q = torch.zeros(
1294
+ batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype
1295
+ )
1296
+
1297
+ return q_input, padded_q, seq_lens, cum_seq_lens
1298
+
1299
+ def _create_test_output_data(
1300
+ self,
1301
+ batch_size,
1302
+ token_per_batch,
1303
+ tp_q_head_num,
1304
+ v_head_dim,
1305
+ dtype=torch.float32,
1306
+ ):
1307
+ """Create test data for unpad kernel testing."""
1308
+ device = torch.device("cuda")
1309
+
1310
+ # Create accept lengths (varying lengths for each batch)
1311
+ accept_lengths = torch.randint(
1312
+ 1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32
1313
+ )
1314
+
1315
+ # Create cumulative accept lengths
1316
+ cum_accept_lengths = torch.zeros(
1317
+ batch_size + 1, device=device, dtype=torch.int32
1318
+ )
1319
+ cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0)
1320
+
1321
+ # Create raw output tensor (batch format)
1322
+ raw_out = torch.randn(
1323
+ batch_size,
1324
+ token_per_batch,
1325
+ tp_q_head_num,
1326
+ v_head_dim,
1327
+ device=device,
1328
+ dtype=dtype,
1329
+ )
1330
+
1331
+ # Create output tensor (flattened format)
1332
+ total_tokens = cum_accept_lengths[-1].item()
1333
+ output = torch.empty(
1334
+ total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype
1335
+ )
1336
+
1337
+ return raw_out, output, accept_lengths, cum_accept_lengths
1338
+
1339
+ # Test 1: pad_draft_extend_query_kernel basic functionality
1340
+ with self.subTest(test="pad_kernel_basic"):
1341
+ batch_size = 4
1342
+ max_seq_len = 8
1343
+ num_heads = 16
1344
+ head_dim = 64
1345
+
1346
+ q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data(
1347
+ self, batch_size, max_seq_len, num_heads, head_dim
1348
+ )
1349
+
1350
+ # Launch kernel
1351
+ BLOCK_SIZE = 64
1352
+ grid = (batch_size * max_seq_len,)
1353
+
1354
+ pad_draft_extend_query_kernel[grid](
1355
+ q_ptr=q_input,
1356
+ padded_q_ptr=padded_q,
1357
+ seq_lens_q_ptr=seq_lens,
1358
+ cumsum_ptr=cum_seq_lens,
1359
+ batch_size=batch_size,
1360
+ max_seq_len=max_seq_len,
1361
+ num_heads=num_heads,
1362
+ head_dim=head_dim,
1363
+ BLOCK_SIZE=BLOCK_SIZE,
1364
+ )
1365
+
1366
+ # Verify the padding worked correctly
1367
+ for i in range(batch_size):
1368
+ seq_len = seq_lens[i].item()
1369
+
1370
+ # Check that valid positions are copied correctly
1371
+ for pos in range(seq_len):
1372
+ input_start = cum_seq_lens[i].item()
1373
+ input_pos = input_start + pos
1374
+
1375
+ # Compare input and output for valid positions
1376
+ input_data = q_input[input_pos]
1377
+ output_data = padded_q[i, pos]
1378
+
1379
+ torch.testing.assert_close(
1380
+ input_data, output_data, rtol=1e-5, atol=1e-6
1381
+ )
1382
+
1383
+ # Check that invalid positions are zero
1384
+ for pos in range(seq_len, max_seq_len):
1385
+ output_data = padded_q[i, pos]
1386
+ self.assertTrue(
1387
+ torch.allclose(output_data, torch.zeros_like(output_data)),
1388
+ f"Position {pos} in batch {i} should be zero",
1389
+ )
1390
+
1391
+ # Test 2: unpad_draft_extend_output_kernel basic functionality
1392
+ with self.subTest(test="unpad_kernel_basic"):
1393
+ batch_size = 4
1394
+ token_per_batch = 8
1395
+ tp_q_head_num = 16
1396
+ v_head_dim = 64
1397
+
1398
+ raw_out, output, accept_lengths, cum_accept_lengths = (
1399
+ _create_test_output_data(
1400
+ self, batch_size, token_per_batch, tp_q_head_num, v_head_dim
1401
+ )
1402
+ )
1403
+
1404
+ # Launch kernel
1405
+ BLOCK_SIZE = 64
1406
+ grid = (batch_size * token_per_batch,)
1407
+
1408
+ unpad_draft_extend_output_kernel[grid](
1409
+ raw_out_ptr=raw_out,
1410
+ output_ptr=output,
1411
+ accept_length_ptr=accept_lengths,
1412
+ cumsum_ptr=cum_accept_lengths,
1413
+ batch_size=batch_size,
1414
+ token_per_batch=token_per_batch,
1415
+ tp_q_head_num=tp_q_head_num,
1416
+ v_head_dim=v_head_dim,
1417
+ BLOCK_SIZE=BLOCK_SIZE,
1418
+ )
1419
+
1420
+ # Verify the unpadding worked correctly
1421
+ for i in range(batch_size):
1422
+ accept_len = accept_lengths[i].item()
1423
+ output_start = cum_accept_lengths[i].item()
1424
+
1425
+ # Check that valid positions are copied correctly
1426
+ for pos in range(accept_len):
1427
+ input_data = raw_out[i, pos]
1428
+ output_data = output[output_start + pos]
1429
+
1430
+ torch.testing.assert_close(
1431
+ input_data, output_data, rtol=1e-5, atol=1e-6
1432
+ )
1433
+
1266
1434
 
1267
1435
  if __name__ == "__main__":
1268
1436
  unittest.main()
@@ -1,16 +1,14 @@
1
1
  import argparse
2
2
  import ast
3
3
  import asyncio
4
- import json
5
4
  import re
6
5
  import time
6
+ from typing import Optional
7
7
 
8
8
  import numpy as np
9
9
 
10
10
  import sglang as sgl
11
- from sglang.lang.api import set_default_backend
12
- from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
13
- from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
11
+ from sglang.utils import download_and_cache_file, read_jsonl
14
12
 
15
13
  INVALID = -9999999
16
14