sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,437 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
2
+
3
+
4
+ import ast
5
+ import dataclasses
6
+ import logging
7
+ import os
8
+ import pprint
9
+ import time
10
+ from collections.abc import Sequence
11
+ from contextlib import contextmanager
12
+ from typing import Any, Callable, Optional
13
+
14
+ import torch
15
+ import torch.fx as fx
16
+ from torch._dispatch.python import enable_python_dispatcher
17
+
18
+ from sglang.srt.compilation.compilation_config import CompilationConfig
19
+ from sglang.srt.compilation.compilation_counter import compilation_counter
20
+ from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
21
+ from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
22
+ from sglang.srt.compilation.pass_manager import PostGradPassManager
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
def make_compiler(config: CompilationConfig):
    """Instantiate the compiler adapter selected by ``config.compiler``.

    ``"eager"`` maps to :class:`EagerAdapter` and ``"inductor"`` maps to
    :class:`InductorAdaptor`.

    Raises:
        ValueError: if ``config.compiler`` names neither supported backend.
    """
    backend = config.compiler
    if backend == "eager":
        return EagerAdapter()
    if backend == "inductor":
        return InductorAdaptor()
    raise ValueError(f"Unknown compiler: {backend}")
34
+
35
+
36
class CompilerManager:
    """Compile FX subgraphs with the configured backend and track their handles.

    Handles returned by the compiler are stored in ``self.cache`` keyed by
    ``(runtime_shape, graph_index, compiler_name)``. ``save_to_file`` can
    persist that mapping to ``<cache_dir>/sglang_compile_cache.py`` as a
    pretty-printed Python literal, and ``initialize_cache`` reads it back.
    """

    def __init__(
        self,
        config: CompilationConfig,
    ):
        self.cache = dict()
        self.is_cache_updated = False
        self.compiler = make_compiler(config)
        # Until initialize_cache() configures a location, treat the on-disk
        # cache as disabled so save_to_file() is a harmless no-op rather than
        # failing with AttributeError on attributes that do not exist yet.
        self.disable_cache = True
        self.cache_dir = None
        self.cache_file_path = None

    def compute_hash(self):
        """Return the hash reported by the underlying compiler adapter."""
        return self.compiler.compute_hash()

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Configure the cache location and load any previously saved handles.

        Args:
            cache_dir: directory holding ``sglang_compile_cache.py``.
            disable_cache: when True, nothing is read here and
                ``save_to_file`` will not write.
            prefix: forwarded unchanged to the compiler's own
                ``initialize_cache``.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                contents = f.read()
            try:
                # literal_eval only accepts Python literals, so a tampered
                # file cannot execute code here.
                loaded = ast.literal_eval(contents)
            except (ValueError, SyntaxError) as e:
                # The cache is only an optimization: a truncated/corrupted
                # file should not prevent startup.
                logger.warning(
                    "Ignoring unreadable compile cache file %s: %s",
                    self.cache_file_path,
                    e,
                )
            else:
                if isinstance(loaded, dict):
                    self.cache = loaded
                else:
                    logger.warning(
                        "Ignoring compile cache file %s: expected a dict, got %s",
                        self.cache_file_path,
                        type(loaded).__name__,
                    )

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist ``self.cache`` if caching is enabled and it has changed.

        The data is written to a temporary sibling file and then atomically
        swapped in with ``os.replace``. A crash mid-write therefore cannot
        leave a truncated cache file behind for ``initialize_cache`` to choke on.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is re-raised regardless.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        """Load a previously compiled graph through its cached handle.

        Raises:
            KeyError: if no handle is cached for
                ``(runtime_shape, graph_index, self.compiler.name)``.
        """
        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via "
                "handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via " "handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        inductor_config: dict[str, Any],
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        """Compile one subgraph and record its handle in the in-memory cache.

        Compilation wall time is measured from the call with
        ``graph_index == 0`` to the call with ``graph_index == num_graphs - 1``
        via the module-level ``compilation_start_time`` global. This assumes
        graph 0 is compiled first within the same process — TODO confirm
        callers guarantee this.

        Returns:
            The compiled callable produced by ``self.compiler.compile``.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # TODO(Yuwei): support cache loading

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, inductor_config, runtime_shape, maybe_key
        )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via " "handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
175
+
176
+
177
@dataclasses.dataclass
class SplitItem:
    """One piece of a graph produced by :func:`split_graph`."""

    # Attribute name of the piece on the stitching module, e.g. "submod_3".
    submod_name: str
    # Integer id parsed from ``submod_name``; pieces are ordered by it.
    graph_id: int
    # True when this piece contains one of the splitting ops.
    is_splitting_graph: bool
    # The piece itself.
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` so each call to one of ``ops`` lands in its own subgraph.

    A ``call_function`` node whose ``str(node.target)`` is in ``ops`` gets a
    fresh subgraph id, and the following nodes start another new id. Runs of
    ordinary nodes between splitting ops therefore share one subgraph.
    Placeholder and output nodes are not assigned here.

    Args:
        graph: The FX graph module to split.
        ops: String forms of the call targets to split on.

    Returns:
        ``(split_gm, items)``. ``split_gm`` is the stitching module produced
        by ``torch.fx.passes.split_module.split_module``. ``items`` describes
        each direct ``submod_<id>`` child, sorted by integer id.
    """
    # Explicit import: do not rely on ``torch.fx.passes`` having been
    # imported as a side effect elsewhere.
    from torch.fx.passes.split_module import split_module

    subgraph_id = 0
    node_to_subgraph_id = {}
    # Set, not list: membership is tested once per produced submodule.
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in ops:
            # Isolate the splitting op: new id for it, another for what follows.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # Otherwise pytorch might reorder the nodes, and the semantics of the
    # graph would change when the graph contains mutations.
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    # Only direct children are pieces; nested modules belong to those pieces.
    outputs = []
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # Sort by integer graph_id rather than by string name ("submod_10" < "submod_2").
    outputs.sort(key=lambda item: item.graph_id)

    return split_gm, outputs
229
+
230
+
231
# We share the global graph pool among all the backends.
# NOTE(review): nothing in this section reads or assigns it; presumably it is
# set elsewhere -- confirm.
global_graph_pool = None

# time.time() timestamp taken by the compiler manager's ``compile`` when it
# starts the first subgraph (graph_index == 0). It is read again after the
# last subgraph to log the total compile time.
compilation_start_time = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Interpret the stitching graph on fake tensors, compiling pieces as they run.

    ``run`` converts tensor inputs to fake tensors and interprets the split
    graph. Each submodule call whose target is in ``compile_submod_names`` is
    compiled for dynamic shape through ``sglang_backend.compiler_manager``.
    A ``CUDAPiecewiseBackend`` wrapping the result is then stored in the
    stitching module's instance ``__dict__`` under the same name, so
    attribute lookup of that name returns the wrapper instead of the
    original submodule. Submodules not in the list (the splitting ops) are
    executed but left untouched.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        inductor_config: dict[str, Any],
        graph_pool,
        compile_config: CompilationConfig,
        sglang_backend: "SGLangBackend",
    ):
        """Set up the interpreter.

        Args:
            module: The stitching module produced by ``split_graph``.
            compile_submod_names: Names of the submodules to compile, in order;
                the position in this list becomes the graph index.
            inductor_config: Config dict forwarded to the compiler.
            graph_pool: Forwarded to every ``CUDAPiecewiseBackend``.
            compile_config: Forwarded to every ``CUDAPiecewiseBackend``.
            sglang_backend: Owner backend; its ``compiler_manager`` compiles.
        """
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # NOTE(review): detect_fake_mode() may return None outside a tracing
        # context, in which case run() fails on ``from_tensor``.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.graph_pool = graph_pool
        self.sglang_backend = sglang_backend
        # When True, fx.Interpreter dumps the whole torch.fx.Graph on errors.
        self.extra_traceback = False
        self.inductor_config = inductor_config
        self.compile_config = compile_config

    def run(self, *args):
        """Interpret the graph with tensor args converted to fake tensors."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule ``target``; compile and wrap it if it is compilable."""
        assert isinstance(target, str)
        # Execute first so downstream nodes receive this submodule's outputs.
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            num_graphs = len(self.compile_submod_names)
            submod = self.fetch_attr(target)
            # Positions of symbolic-int arguments (symbolic sizes).
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_dynamic_shape = (
                self.sglang_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.inductor_config,
                    graph_index=index,
                    num_graphs=num_graphs,
                    runtime_shape=None,
                )
            )

            # Shadow the registered submodule via the instance __dict__.
            self.module.__dict__[target] = CUDAPiecewiseBackend(
                submod,
                self.compile_config,
                self.inductor_config,
                self.graph_pool,
                index,
                num_graphs,
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.sglang_backend,
            )

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
309
+
310
+
311
# Tag of the model currently being compiled. The backend appends it to the
# per-rank cache directory path, so different models get separate caches.
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily switch the module-level ``model_tag`` to ``tag``.

    The previous tag is restored on exit, including when the body raises.

    Raises:
        AssertionError: If ``tag`` equals the tag already in effect.
    """
    global model_tag
    saved_tag = model_tag
    assert (
        tag != saved_tag
    ), f"Model tag {tag} is the same as the current tag {saved_tag}."
    model_tag = tag
    try:
        yield
    finally:
        model_tag = saved_tag
327
+
328
+
329
class SGLangBackend:
    """Piecewise compile backend with the ``(graph, example_inputs)`` signature.

    When called, it splits the FX graph at ``sglang.unified_attention_with_output``
    calls. It then compiles every non-splitting piece through a
    ``PiecewiseCompileInterpreter`` and returns the stitching module. An
    instance may only be called once.
    """

    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        config: CompilationConfig,
        graph_pool: Any,
    ):
        """Create the backend.

        Args:
            config: Compilation settings; also used to build the CompilerManager.
            graph_pool: Must not be ``None``; forwarded to every piecewise
                backend created during compilation.
        """
        assert graph_pool is not None
        self.graph_pool = graph_pool

        self.post_grad_pass_manager = PostGradPassManager()
        self.sym_tensor_indices = []
        self.input_buffers = []

        self.compiler_manager = CompilerManager(config)
        self.inductor_config = {
            "enable_auto_functionalized_v2": False,
        }
        self.compile_config = config

    def configure_post_pass(self):
        """Configure the post-grad pass manager and register it in the inductor config."""
        self.post_grad_pass_manager.configure()
        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

    def _init_local_cache_dir(self) -> str:
        """Create this model's cache directory, initialize the compiler cache there, and return its path.

        The layout is ``$SGLANG_CACHE_DIR`` (default ``~/.cache/sglang/``)
        / ``torch_compile_cache`` / <compiler-manager hash> /
        ``rank_<rank>_<dp_rank>`` / <model_tag>.
        """
        base_cache_dir = os.path.expanduser(
            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
        )
        cache_dir = os.path.join(
            base_cache_dir,
            "torch_compile_cache",
            self.compiler_manager.compute_hash(),
        )
        os.makedirs(cache_dir, exist_ok=True)
        # NOTE(review): rank and dp_rank are hard-coded to 0 here.
        rank = 0
        dp_rank = 0
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache=False, prefix=""
        )
        return local_cache_dir

    def _dump_computation_graph(self, local_cache_dir: str) -> None:
        """Write ``split_gm`` as readable source to ``computation_graph.py`` unless it already exists."""
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        # Explicit UTF-8: the dumped code may contain non-ASCII text, which
        # the locale default encoding might not be able to encode.
        with open(graph_path, "w", encoding="utf-8") as f:
            f.write(src)
        logger.debug("Computation graph saved to %s", graph_path)

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split and compile ``graph``; return the stitching module.

        Raises:
            AssertionError: If this backend instance was already called.
        """
        # Guard first so a repeated call has no side effects (directories,
        # cache initialization, counters).
        assert not self._called, "SGLangBackend can only be called once"

        local_cache_dir = self._init_local_cache_dir()
        compilation_counter.num_graphs_seen += 1

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, ["sglang.unified_attention_with_output"]
        )

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)

        # Splitting-op pieces stay as they are; everything else is compiled.
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        PiecewiseCompileInterpreter(
            self.split_gm,
            submod_names_to_compile,
            self.inductor_config,
            self.graph_pool,
            self.compile_config,
            self,
        ).run(*example_inputs)

        self._dump_computation_graph(local_cache_dir)

        self._called = True
        return self.split_gm
@@ -0,0 +1,20 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
2
+
3
+ from typing import List
4
+
5
+
6
+ # TODO(Yuwei): improve compile config support
7
class CompilationConfig:
    """Settings for piecewise compilation.

    Args:
        capture_sizes: Sizes to capture. The list is stored as given, not copied.
        compiler: Name of the compiler to use (default ``"eager"``).
    """

    def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
        # Paths registered through add_traced_file(); duplicates collapse.
        self.traced_files = set()
        self.capture_sizes = capture_sizes
        self.compiler = compiler

    def add_traced_file(self, file_path: str):
        """Register ``file_path`` in the traced-file set."""
        traced = self.traced_files
        traced.add(file_path)

    def get_traced_files(self):
        """Return the live set of registered traced-file paths (not a copy)."""
        traced = self.traced_files
        return traced

    def get_capture_sizes(self):
        """Return the capture-size list exactly as passed to the constructor."""
        sizes = self.capture_sizes
        return sizes
@@ -0,0 +1,47 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
2
+
3
+ import copy
4
+ import dataclasses
5
+ from contextlib import contextmanager
6
+
7
+
8
@dataclasses.dataclass
class CompilationCounter:
    """Integer counters tracking compilation events; all start at zero."""

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # Pieces after splitting, splitting-op pieces included.
    num_piecewise_graphs_seen: int = 0
    # Pieces after splitting, splitting-op pieces excluded.
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # Attempts by the model runner to trigger CUDA graph capture.
    num_gpu_runner_capture_triggers: int = 0
    # CUDA graphs actually captured.
    num_cudagraph_captured: int = 0
    # Calls to the Inductor adaptor's compile.
    num_inductor_compiles: int = 0
    # Calls to the eager adaptor's compile.
    num_eager_compiles: int = 0
    # Times an entry of the compiler cache was updated.
    num_cache_entries_updated: int = 0
    # Compiled artifacts saved by standalone_compile.
    num_compiled_artifacts_saved: int = 0
    # Times a model was loaded with CompilationLevel.DYNAMO_AS_IS.
    dynamo_as_is_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of this counter."""
        snapshot = copy.deepcopy(self)
        return snapshot

    @contextmanager
    def expect(self, **kwargs):
        """Assert that each named counter changes by exactly the given amount.

        Usage: ``with counter.expect(num_graphs_seen=1): ...``. The deltas are
        checked only when the body exits normally.

        Raises:
            AssertionError: On the first counter whose change differs.
        """
        before_state = self.clone()
        yield
        for field_name, expected_delta in kwargs.items():
            before = getattr(before_state, field_name)
            after = getattr(self, field_name)
            assert after - before == expected_delta, (
                f"{field_name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {expected_delta}"
            )


# Module-level instance that the compilation code increments.
compilation_counter = CompilationCounter()