sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/compilation/cuda_piecewise_backend.py
@@ -0,0 +1,228 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
+
+import dataclasses
+import logging
+from contextlib import ExitStack
+from typing import Any, Callable, Optional, Union
+from unittest.mock import patch
+
+import torch
+import torch.fx as fx
+
+import sglang.srt.compilation.weak_ref_tensor_jit  # noqa: F401
+from sglang.srt.compilation.compilation_config import CompilationConfig
+from sglang.srt.compilation.compilation_counter import compilation_counter
+
+logger = logging.getLogger(__name__)
+
+
+def weak_ref_tensor(tensor: Any) -> Any:
+    """
+    Create a weak reference to a tensor.
+    The new tensor will share the same data as the original tensor,
+    but will not keep the original tensor alive.
+    """
+    if isinstance(tensor, torch.Tensor):
+        # TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
+        return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)
+    return tensor
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")
+
+
+@dataclasses.dataclass
+class ConcreteSizeEntry:
+    runtime_shape: int
+    need_to_compile: bool  # the size is in compile_sizes
+    use_cudagraph: bool  # the size is in cudagraph_capture_sizes
+
+    compiled: bool = False
+    runnable: Callable = None  # type: ignore
+    num_finished_warmup: int = 0
+    cudagraph: Optional[torch.cuda.CUDAGraph] = None
+    output: Optional[Any] = None
+
+    # for cudagraph debugging, track the input addresses
+    # during capture, and check if they are the same during replay
+    input_addresses: Optional[list[int]] = None
+
+
+class CUDAPiecewiseBackend:
+
+    def __init__(
+        self,
+        graph: fx.GraphModule,
+        compile_config: CompilationConfig,
+        inductor_config: dict[str, Any],
+        graph_pool: Any,
+        piecewise_compile_index: int,
+        total_piecewise_compiles: int,
+        sym_shape_indices: list[int],
+        compiled_graph_for_general_shape: Callable,
+        sglang_backend,
+    ):
+        """
+        The backend for piecewise compilation.
+        It mainly handles the compilation and cudagraph capturing.
+
+        We will compile `self.graph` once for the general shape,
+        and then compile for different shapes specified in
+        `compilation_config.compile_sizes`.
+
+        Independently, we will capture cudagraph for different shapes.
+
+        If a shape needs both compilation and cudagraph, we will
+        compile it first, and then capture cudagraph.
+        """
+        self.graph = graph
+        self.inductor_config = inductor_config
+        self.graph_pool = graph_pool
+        self.piecewise_compile_index = piecewise_compile_index
+        self.total_piecewise_compiles = total_piecewise_compiles
+        self.sglang_backend = sglang_backend
+
+        self.is_first_graph = piecewise_compile_index == 0
+        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
+
+        self.compile_sizes: set[int] = set([])
+        self.compile_config = compile_config
+        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())
+
+        self.first_run_finished = False
+
+        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
+
+        self.sym_shape_indices = sym_shape_indices
+
+        self.is_debugging_mode = True
+
+        # the entries for different shapes that we need to either
+        # compile or capture cudagraph
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
+
+        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # and updates during the compilation process, so we need to copy it
+        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
+        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=shape in self.compile_sizes,
+                use_cudagraph=shape in self.cudagraph_capture_sizes,
+            )
+
+    def check_for_ending_compilation(self):
+        if self.is_last_graph and not self.to_be_compiled_sizes:
+            # no specific sizes to compile
+            # save the hash of the inductor graph for the next run
+            self.sglang_backend.compiler_manager.save_to_file()
+
+    def __call__(self, *args) -> Any:
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            self.check_for_ending_compilation()
+            return self.compiled_graph_for_general_shape(*args)
+        runtime_shape = args[self.sym_shape_indices[0]]
+        if runtime_shape not in self.concrete_size_entries:
+            # we don't need to do anything for this shape
+            return self.compiled_graph_for_general_shape(*args)
+
+        entry = self.concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.compiled_graph_for_general_shape
+
+        if entry.need_to_compile and not entry.compiled:
+            entry.compiled = True
+            self.to_be_compiled_sizes.remove(runtime_shape)
+            # args are real arguments
+            entry.runnable = self.sglang_backend.compiler_manager.compile(
+                self.graph,
+                args,
+                self.inductor_config,
+                graph_index=self.piecewise_compile_index,
+                num_graphs=self.total_piecewise_compiles,
+                runtime_shape=runtime_shape,
+            )
+
+            # finished compilations for all required shapes
+            if self.is_last_graph and not self.to_be_compiled_sizes:
+                self.check_for_ending_compilation()
+
+        # Skip CUDA graphs if this entry doesn't use them OR
+        # if we're supposed to skip them globally
+        # skip_cuda_graphs = get_forward_context().skip_cuda_graphs
+        # if not entry.use_cudagraph or skip_cuda_graphs:
+        #     return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < 1:  # noqa
+                entry.num_finished_warmup += 1
+                return entry.runnable(*args)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            cudagraph = torch.cuda.CUDAGraph()
+
+            with ExitStack() as stack:
+                if not self.is_first_graph:
+                    # during every model forward, we will capture
+                    # many pieces of cudagraphs (roughly one per layer).
+                    # running gc again and again across layers will
+                    # make the cudagraph capture very slow.
+                    # therefore, we only run gc for the first graph,
+                    # and disable gc for the rest of the graphs.
+                    stack.enter_context(patch("gc.collect", lambda: None))
+                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
+
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    if self.is_last_graph:
+                        # by converting it to weak ref,
+                        # the original `output` will immediately be released
+                        # to save memory. It is only safe to do this for
+                        # the last graph, because the output of the last graph
+                        # will not be used by any other cuda graph.
+                        output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.cudagraph.replay()
+        return entry.output
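
Note: the dispatch policy in `__call__` above can be summarized with a CPU-only sketch (the names below are illustrative, not part of the package): a shape listed for capture gets one eager warmup run, is captured into a CUDA graph on its next hit, and replays afterwards; any other shape falls back to the general compiled graph.

    from dataclasses import dataclass

    @dataclass
    class _Entry:                     # mirrors ConcreteSizeEntry, minus the torch types
        num_finished_warmup: int = 0
        captured: bool = False

    # stand-in for compile_config.get_capture_sizes()
    entries = {s: _Entry() for s in (1, 2, 4, 8)}

    def dispatch(runtime_shape: int) -> str:
        entry = entries.get(runtime_shape)
        if entry is None:
            return "general compiled graph"            # shape not listed for capture
        if not entry.captured:
            if entry.num_finished_warmup < 1:          # one eager warmup before capture
                entry.num_finished_warmup += 1
                return "eager warmup run"
            entry.captured = True
            return "capture cudagraph, return real output"
        return "replay captured cudagraph"

    for shape in (3, 4, 4, 4):
        print(shape, "->", dispatch(shape))
    # 3 -> general compiled graph
    # 4 -> eager warmup run
    # 4 -> capture cudagraph, return real output
    # 4 -> replay captured cudagraph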
sglang/srt/compilation/fix_functionalization.py
@@ -0,0 +1,134 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
+
+import logging
+import operator
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+
+from sglang.srt.compilation.fx_utils import is_func
+from sglang.srt.compilation.inductor_pass import SGLangInductorPass
+
+logger = logging.getLogger(__name__)
+
+
+class FixFunctionalizationPass(SGLangInductorPass):
+    """
+    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
+    After this pass, DCE (dead-code elimination) should never be run,
+    as de-functionalized nodes may appear as dead code.
+
+    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
+    """
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_fix_functionalization")
+
+        self.nodes_to_remove: list[torch.fx.Node] = []
+        count = 0
+        for node in graph.nodes:
+            if not is_func(node, auto_functionalized):
+                continue  # Avoid deep if-elif nesting
+            count += 1
+
+        self.dump_graph(graph, "before_fix_functionalization_cleanup")
+
+        # Remove the nodes all at once
+        count_removed = len(self.nodes_to_remove)
+        for node in self.nodes_to_remove:
+            graph.erase_node(node)
+
+        logger.debug(
+            "De-functionalized %s nodes, removed %s nodes", count, count_removed
+        )
+        self.dump_graph(graph, "after_fix_functionalization")
+        self.end_and_log()
+
+    def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
+        """
+        Stage a node (or nodes) for removal at the end of the pass.
+        """
+        if isinstance(node_or_nodes, torch.fx.Node):
+            self.nodes_to_remove.append(node_or_nodes)
+        else:
+            self.nodes_to_remove.extend(node_or_nodes)
+
+    def defunctionalize(
+        self,
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        mutated_args: dict[int, Union[torch.fx.Node, str]],
+        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
+    ):
+        """
+        De-functionalize a node by replacing it with a call to the original.
+        It also replaces the getitem users with the mutated arguments.
+        See replace_users_with_mutated_args and insert_defunctionalized.
+        """
+        self.replace_users_with_mutated_args(node, mutated_args)
+        self.insert_defunctionalized(graph, node, args=args)
+        self._remove(node)
+
+    def replace_users_with_mutated_args(
+        self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
+    ):
+        """
+        Replace all getitem users of the auto-functionalized node with the
+        mutated arguments.
+        :param node: The auto-functionalized node
+        :param mutated_args: The mutated arguments, indexed by getitem index.
+        If the value of an arg is a string, `node.kwargs[arg]` is used.
+        """
+        for idx, user in self.getitem_users(node).items():
+            arg = mutated_args[idx]
+            arg = node.kwargs[arg] if isinstance(arg, str) else arg
+            user.replace_all_uses_with(arg)
+            self._remove(user)
+
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
+        """
+        Returns the operator.getitem users of the auto-functionalized node,
+        indexed by the index they are getting.
+        """
+        users = {}
+        for user in node.users:
+            if is_func(user, operator.getitem):
+                idx = user.args[1]
+                users[idx] = user
+        return users
+
+    def insert_defunctionalized(
+        self,
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
+    ):
+        """
+        Insert a new defunctionalized node into the graph before node.
+        If one of the kwargs is 'out', provide args directly,
+        as node.kwargs cannot be used.
+        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
+
+        :param graph: Graph to insert the defunctionalized node into
+        :param node: The auto-functionalized node to defunctionalize
+        :param args: If we cannot use kwargs, specify args directly.
+        If an arg is a string, `node.kwargs[arg]` is used.
+        """  # noqa: E501
+        assert is_func(
+            node, auto_functionalized
+        ), f"node must be auto-functionalized, is {node} instead"
+
+        # Create a new call to the original function
+        with graph.inserting_before(node):
+            function = node.args[0]
+            if args is None:
+                graph.call_function(function, kwargs=node.kwargs)
+            else:
+                # Args passed as strings refer to items in node.kwargs
+                args = tuple(
+                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
+                )
+                graph.call_function(function, args=args)
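
Note: the getitem bookkeeping is the easiest part of this pass to exercise in isolation. A small sketch, assuming sglang 0.5.4 and a recent PyTorch are importable; a traced `torch.split` stands in for an auto-functionalized node because its outputs are likewise consumed through `operator.getitem`.

    import operator

    import torch
    from torch import fx

    from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass

    def f(x):
        parts = torch.split(x, 2)      # one node whose outputs are read via operator.getitem
        return parts[0] + parts[1]

    gm = fx.symbolic_trace(f)
    split_node = next(
        n for n in gm.graph.nodes
        if any(u.op == "call_function" and u.target == operator.getitem for u in n.users)
    )

    # maps each extracted index to the getitem node that reads it
    users = FixFunctionalizationPass().getitem_users(split_node)
    print(sorted(users))               # -> [0, 1]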
sglang/srt/compilation/fx_utils.py
@@ -0,0 +1,83 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
+
+import operator
+from collections.abc import Iterable, Iterator
+from typing import Optional
+
+from torch import fx
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._ops import OpOverload
+
+
+def is_func(node: fx.Node, target) -> bool:
+    return node.op == "call_function" and node.target == target
+
+
+def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
+    return is_func(node, auto_functionalized) and node.args[0] == op
+
+
+# Returns the first specified node with the given op (if it exists)
+def find_specified_fn_maybe(
+    nodes: Iterable[fx.Node], op: OpOverload
+) -> Optional[fx.Node]:
+    for node in nodes:
+        if node.target == op:
+            return node
+    return None
+
+
+# Returns the first specified node with the given op
+def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
+    node = find_specified_fn_maybe(nodes, op)
+    assert node is not None, f"Could not find {op} in nodes {nodes}"
+    return node
+
+
+# Returns the first auto_functionalized node with the given op (if it exists)
+def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
+    for node in nodes:
+        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
+            return node
+    return None
+
+
+# Returns the first auto_functionalized node with the given op
+def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
+    node = find_auto_fn_maybe(nodes, op)
+    assert node is not None, f"Could not find {op} in nodes {nodes}"
+    return node
+
+
+# Returns the getitem node that extracts the idx-th element from node
+# (if it exists)
+def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
+    for user in node.users:
+        if is_func(user, operator.getitem) and user.args[1] == idx:
+            return user
+    return None
+
+
+# Returns the getitem node that extracts the idx-th element from node
+def find_getitem(node: fx.Node, idx: int) -> fx.Node:
+    ret = find_getitem_maybe(node, idx)
+    assert ret is not None, f"Could not find getitem {idx} in node {node}"
+    return ret
+
+
+# An auto-functionalization-aware utility for finding nodes with a specific op
+def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
+    if not op._schema.is_mutable:
+        yield from graph.find_nodes(op="call_function", target=op)
+
+    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
+        if n.args[0] == op:
+            yield n
+
+
+# Asserts that the node only has one user and returns it
+# Even if a node has only 1 user, it might share storage with another node,
+# which might need to be taken into account.
+def get_only_user(node: fx.Node) -> fx.Node:
+    assert len(node.users) == 1
+    return next(iter(node.users))
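
Note: a brief usage sketch of the lookup helpers on an ordinary traced graph (no auto-functionalized nodes are needed for these three), assuming sglang 0.5.4 and PyTorch are installed.

    import torch
    from torch import fx

    from sglang.srt.compilation.fx_utils import find_specified_fn_maybe, get_only_user, is_func

    def f(x):
        return torch.relu(x) + 1

    gm = fx.symbolic_trace(f)

    relu = find_specified_fn_maybe(gm.graph.nodes, torch.relu)  # first call_function node targeting torch.relu
    assert relu is not None and is_func(relu, torch.relu)
    print(get_only_user(relu))                                  # the add node consuming relu's output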
sglang/srt/compilation/inductor_pass.py
@@ -0,0 +1,140 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
+
+import hashlib
+import inspect
+import json
+import logging
+import time
+import types
+from contextlib import contextmanager
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch import fx
+from torch._dynamo.utils import lazy_format_graph_code
+from torch._inductor.custom_graph_pass import CustomGraphPass
+
+logger = logging.getLogger(__name__)
+
+_pass_context = None
+
+
+class PassContext:
+
+    def __init__(self, runtime_shape: Optional[int]):
+        self.runtime_shape = runtime_shape
+
+
+def get_pass_context() -> PassContext:
+    """Get the current pass context."""
+    assert _pass_context is not None
+    return _pass_context
+
+
+@contextmanager
+def pass_context(runtime_shape: Optional[int]):
+    """A context manager that stores the current pass context,
+    usually it is a list of sizes to specialize.
+    """
+    global _pass_context
+    prev_context = _pass_context
+    _pass_context = PassContext(runtime_shape)
+    try:
+        yield
+    finally:
+        _pass_context = prev_context
+
+
+class InductorPass(CustomGraphPass):
+    """
+    A custom graph pass that uses a hash of its source as the UUID.
+    This is defined as a convenience and should work in most cases.
+    """
+
+    def uuid(self) -> Any:
+        """
+        Provide a unique identifier for the pass, used in Inductor code cache.
+        This should depend on the pass implementation, so that changes to the
+        pass result in recompilation.
+        By default, the object source is hashed.
+        """
+        return InductorPass.hash_source(self)
+
+    @staticmethod
+    def hash_source(*srcs: Union[str, Any]):
+        """
+        Utility method to hash the sources of functions or objects.
+        :param srcs: strings or objects to add to the hash.
+        Objects and functions have their source inspected.
+        :return:
+        """
+        hasher = hashlib.sha256()
+        for src in srcs:
+            if isinstance(src, str):
+                src_str = src
+            elif isinstance(src, types.FunctionType):
+                src_str = inspect.getsource(src)
+            else:
+                src_str = inspect.getsource(src.__class__)
+            hasher.update(src_str.encode("utf-8"))
+        return hasher.hexdigest()
+
+    @staticmethod
+    def hash_dict(dict_: dict[Any, Any]):
+        """
+        Utility method to hash a dictionary, can alternatively be used for uuid.
+        :return: A sha256 hash of the json rep of the dictionary.
+        """
+        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
+        return hashlib.sha256(encoded).hexdigest()
+
+    def is_applicable_for_shape(self, shape: Optional[int]):
+        return True
+
+
+class CallableInductorPass(InductorPass):
+    """
+    This class is a wrapper for a callable that automatically provides an
+    implementation of the UUID.
+    """
+
+    def __init__(
+        self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
+    ):
+        self.callable = callable
+        self._uuid = self.hash_source(callable) if uuid is None else uuid
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.callable(graph)
+
+    def uuid(self) -> Any:
+        return self._uuid
+
+
+class SGLangInductorPass(InductorPass):
+
+    def __init__(
+        self,
+    ):
+        self.pass_name = self.__class__.__name__
+
+    def dump_graph(self, graph: torch.fx.Graph, stage: str):
+        lazy_format_graph_code(stage, graph.owning_module)
+
+    def begin(self):
+        self._start_time = time.perf_counter_ns()
+
+    def end_and_log(self):
+        self._end_time = time.perf_counter_ns()
+        duration_ms = float(self._end_time - self._start_time) / 1.0e6
+        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class PrinterInductorPass(SGLangInductorPass):
+
+    def __init__(self, name: str):
+        super().__init__()
+        self.name = name
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.dump_graph(graph, self.name)
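
Note: a short usage sketch, under the same installation assumptions: `CallableInductorPass` derives its Inductor-cache UUID from the wrapped callable's source, and `pass_context` exposes the shape currently being specialized to any pass running under it.

    from torch import fx

    from sglang.srt.compilation.inductor_pass import (
        CallableInductorPass,
        get_pass_context,
        pass_context,
    )

    def count_nodes(graph: fx.Graph) -> None:     # any Callable[[fx.Graph], None] works
        print("nodes in graph:", len(list(graph.nodes)))

    p = CallableInductorPass(count_nodes)
    print(p.uuid())                               # sha256 of count_nodes' source; changes when the source changes

    with pass_context(runtime_shape=16):          # passes query this while they run
        assert get_pass_context().runtime_shape == 16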
sglang/srt/compilation/pass_manager.py
@@ -0,0 +1,66 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
+
+import logging
+
+from torch import fx as fx
+
+from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass
+from sglang.srt.compilation.inductor_pass import (
+    CustomGraphPass,
+    InductorPass,
+    SGLangInductorPass,
+    get_pass_context,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PostGradPassManager(CustomGraphPass):
+    """
+    The pass manager for post-grad passes.
+    It handles configuration, adding custom passes, and running passes.
+    It supports uuid for the Inductor code cache. That includes torch<2.6
+    support using pickling (in .inductor_pass.CustomGraphPass).
+
+    The order of the post-grad post-passes is:
+    1. passes (constructor parameter)
+    2. default passes (NoopEliminationPass, FusionPass)
+    3. config["post_grad_custom_post_pass"] (if it exists)
+    4. fix_functionalization
+    This way, all passes operate on a functionalized graph.
+    """
+
+    def __init__(self):
+        self.passes: list[SGLangInductorPass] = []
+
+    def __call__(self, graph: fx.Graph):
+        shape = get_pass_context().runtime_shape
+        for pass_ in self.passes:
+            if pass_.is_applicable_for_shape(shape):
+                pass_(graph)
+
+        # always run fix_functionalization last
+        self.fix_functionalization(graph)
+
+    def configure(
+        self,
+    ):
+        self.pass_config = dict()
+        self.fix_functionalization = FixFunctionalizationPass()
+
+    def add(self, pass_: InductorPass):
+        assert isinstance(pass_, InductorPass)
+        self.passes.append(pass_)
+
+    def uuid(self):
+        """
+        The PostGradPassManager is set as a custom pass in the Inductor and
+        affects compilation caching. Its uuid depends on the UUIDs of all
+        dependent passes and the pass config. See InductorPass for more info.
+        """
+        pass_manager_uuid = "fshdakhsa"
+        state = {"pass_config": pass_manager_uuid, "passes": []}
+        for pass_ in self.passes:
+            state["passes"].append(pass_.uuid())
+        state["passes"].append(self.fix_functionalization.uuid())
+        return InductorPass.hash_dict(state)
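
Note: a sketch of how the manager is wired together, under the same installation assumptions; `configure()` installs the trailing `FixFunctionalizationPass`, `add()` appends user passes, and the manager must be invoked inside a `pass_context` so it can read the runtime shape.

    import torch
    from torch import fx

    from sglang.srt.compilation.inductor_pass import CallableInductorPass, pass_context
    from sglang.srt.compilation.pass_manager import PostGradPassManager

    def noop_pass(graph: fx.Graph) -> None:       # placeholder user pass
        pass

    pm = PostGradPassManager()
    pm.configure()                                # sets up the final FixFunctionalizationPass
    pm.add(CallableInductorPass(noop_pass))       # runs before fix_functionalization
    print(pm.uuid())                              # combined hash fed to the Inductor code cache

    gm = fx.symbolic_trace(lambda x: torch.relu(x))
    with pass_context(runtime_shape=None):        # the manager reads the shape from the pass context
        pm(gm.graph)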
sglang/srt/compilation/piecewise_context_manager.py
@@ -0,0 +1,40 @@
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@dataclass
+class ForwardContext:
+    def __init__(self):
+        self.forward_batch = None
+        self.attention_layers = None  # matches the attribute set by set_attention_layers
+
+    def set_forward_batch(self, forward_batch: ForwardBatch):
+        self.forward_batch = forward_batch
+
+    def set_attention_layers(self, layers: List[Any]):
+        self.attention_layers = layers
+
+
+_forward_context: Optional[ForwardContext] = None
+
+
+def get_forward_context() -> Optional[ForwardContext]:
+    if _forward_context is None:
+        return None
+    return _forward_context
+
+
+@contextmanager
+def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
+    global _forward_context
+    prev_forward_context = _forward_context
+    _forward_context = ForwardContext()
+    _forward_context.set_forward_batch(forward_batch)
+    _forward_context.set_attention_layers(attention_layers)
+    try:
+        yield
+    finally:
+        _forward_context = prev_forward_context
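
Note: a minimal sketch of the context-manager contract, assuming sglang 0.5.4 is importable; the placeholder objects stand in for the real `ForwardBatch` and attention-layer list supplied by the model runner.

    from sglang.srt.compilation.piecewise_context_manager import (
        get_forward_context,
        set_forward_context,
    )

    # Placeholders for illustration only; real callers pass the live ForwardBatch
    # and the model's attention layers.
    fake_batch, fake_layers = object(), [object(), object()]

    assert get_forward_context() is None              # nothing set outside the context

    with set_forward_context(fake_batch, fake_layers):
        ctx = get_forward_context()
        assert ctx.forward_batch is fake_batch
        assert ctx.attention_layers is fake_layers    # visible to compiled pieces during forward

    assert get_forward_context() is None              # previous context restored on exit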