sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/batch_invariant_ops/batch_invariant_ops.py
@@ -0,0 +1,547 @@
+ # Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py
+
+ import contextlib
+ from collections import namedtuple
+ from collections.abc import Callable
+ from typing import Any, Dict
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ __all__ = [
+     "set_batch_invariant_mode",
+     "is_batch_invariant_mode_enabled",
+     "disable_batch_invariant_mode",
+     "enable_batch_invariant_mode",
+ ]
+
+
+ def _matmul_launch_metadata(
+     grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]
+ ) -> Dict[str, Any]:
+     ret = {}
+     m, n, k = args["M"], args["N"], args["K"]
+     ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
+     if "tiles_per_update" in args:
+         ret["name"] = (
+             f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
+         )
+     if "c_ptr" in args:
+         bytes_per_elem = args["c_ptr"].element_size()
+     else:
+         bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+     ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
+     ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
+     return ret
+
+
+ @triton.jit
+ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
+     group_id = tile_id // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (tile_id % group_size_m)
+     pid_n = (tile_id % num_pid_in_group) // group_size_m
+     return pid_m, pid_n
+
+
+ @triton.jit(launch_metadata=_matmul_launch_metadata)
+ def matmul_kernel_persistent(
+     a_ptr,
+     b_ptr,
+     c_ptr,  #
+     bias_ptr,
+     M,
+     N,
+     K,  #
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     BLOCK_SIZE_M: tl.constexpr,  #
+     BLOCK_SIZE_N: tl.constexpr,  #
+     BLOCK_SIZE_K: tl.constexpr,  #
+     GROUP_SIZE_M: tl.constexpr,  #
+     NUM_SMS: tl.constexpr,  #
+     A_LARGE: tl.constexpr,
+     B_LARGE: tl.constexpr,
+     C_LARGE: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+ ):
+     start_pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+     num_tiles = num_pid_m * num_pid_n
+
+     offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+     for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
+         pid_m, pid_n = _compute_pid(
+             tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+         )
+         start_m = pid_m * BLOCK_SIZE_M
+         start_n = pid_n * BLOCK_SIZE_N
+         offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
+         offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
+         if A_LARGE:
+             offs_am = offs_am.to(tl.int64)
+         if B_LARGE:
+             offs_bn = offs_bn.to(tl.int64)
+         offs_am = tl.where(offs_am < M, offs_am, 0)
+         offs_bn = tl.where(offs_bn < N, offs_bn, 0)
+         offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
+         offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+
+         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+         for ki in range(k_tiles):
+             if A_LARGE or B_LARGE:
+                 offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+             else:
+                 offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+             a_ptrs = a_ptr + (
+                 offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
+             )
+             b_ptrs = b_ptr + (
+                 offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+             )
+
+             a = tl.load(
+                 a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
+             )
+             b = tl.load(
+                 b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
+             )
+             accumulator = tl.dot(a, b, accumulator)
+
+         offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+         offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+         if C_LARGE:
+             offs_cm = offs_cm.to(tl.int64)
+             offs_cn = offs_cn.to(tl.int64)
+         c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+         if HAS_BIAS:
+             bias_ptrs = bias_ptr + offs_cn
+             bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
+             accumulator += bias
+         if c_ptr.dtype.element_ty == tl.float8e4nv:
+             c = accumulator.to(tl.float8e4nv)
+         elif c_ptr.dtype.element_ty == tl.bfloat16:
+             c = accumulator.to(tl.bfloat16)
+         elif c_ptr.dtype.element_ty == tl.float32:
+             c = accumulator.to(tl.float32)
+         else:
+             c = accumulator.to(tl.float16)
+         tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def matmul_persistent(
+     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+ ):
+     # Check constraints.
+     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+     assert a.dtype == b.dtype, "Incompatible dtypes"
+     assert (
+         bias is None or bias.dim() == 1
+     ), "Currently assuming bias is 1D, let Horace know if you run into this"
+     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+     M, K = a.shape
+     K, N = b.shape
+     dtype = a.dtype
+     # Allocates output.
+     c = torch.empty((M, N), device=a.device, dtype=dtype)
+
+     # 1D launch kernel where each block gets its own program.
+     def grid(META):
+         return (
+             min(
+                 NUM_SMS,
+                 triton.cdiv(M, META["BLOCK_SIZE_M"])
+                 * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+             ),
+         )
+
+     configs = {
+         torch.bfloat16: {
+             "BLOCK_SIZE_M": 128,
+             "BLOCK_SIZE_N": 128,
+             "BLOCK_SIZE_K": 64,
+             "GROUP_SIZE_M": 8,
+             "num_stages": 3,
+             "num_warps": 8,
+         },
+         torch.float16: {
+             "BLOCK_SIZE_M": 128,
+             "BLOCK_SIZE_N": 256,
+             "BLOCK_SIZE_K": 64,
+             "GROUP_SIZE_M": 8,
+             "num_stages": 3,
+             "num_warps": 8,
+         },
+         torch.float32: {
+             "BLOCK_SIZE_M": 128,
+             "BLOCK_SIZE_N": 128,
+             "BLOCK_SIZE_K": 32,
+             "GROUP_SIZE_M": 8,
+             "num_stages": 3,
+             "num_warps": 8,
+         },
+     }
+     # print(a.device, b.device, c.device)
+     matmul_kernel_persistent[grid](
+         a,
+         b,
+         c,  #
+         bias,
+         M,
+         N,
+         K,  #
+         a.stride(0),
+         a.stride(1),  #
+         b.stride(0),
+         b.stride(1),  #
+         c.stride(0),
+         c.stride(1),  #
+         NUM_SMS=NUM_SMS,  #
+         A_LARGE=a.numel() > 2**31,
+         B_LARGE=b.numel() > 2**31,
+         C_LARGE=c.numel() > 2**31,
+         HAS_BIAS=bias is not None,
+         **configs[dtype],
+     )
+     return c
+
+
+ @triton.jit
+ def _log_softmax_kernel(
+     input_ptr,
+     output_ptr,
+     input_row_stride,
+     output_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     Compute log_softmax along the last dimension of a 2D tensor.
+     Each block handles one row of the input tensor.
+     """
+     # Get the row index for this block
+     row_idx = tl.program_id(0).to(tl.int64)
+
+     # Compute base pointers for input and output rows
+     row_start_ptr = input_ptr + row_idx * input_row_stride
+     output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+     # Step 1: Find maximum value in the row for numerical stability
+     max_val = -float("inf")
+     for col_offset in range(0, n_cols, BLOCK_SIZE):
+         col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+         mask = col_idx < n_cols
+
+         # Load values
+         vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
+
+         # Update maximum
+         max_val = tl.max(tl.maximum(vals, max_val))
+
+     # Step 2: Compute sum of exp(x - max_val)
+     sum_exp = 0.0
+     for col_offset in range(0, n_cols, BLOCK_SIZE):
+         col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+         mask = col_idx < n_cols
+
+         # Load values
+         vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+
+         # Compute exp(x - max_val) and accumulate
+         exp_vals = tl.exp(vals - max_val)
+         sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
+
+     # Compute log(sum_exp)
+     log_sum_exp = tl.log(sum_exp)
+
+     # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
+     for col_offset in range(0, n_cols, BLOCK_SIZE):
+         col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+         mask = col_idx < n_cols
+
+         # Load values
+         vals = tl.load(row_start_ptr + col_idx, mask=mask)
+
+         # Compute log_softmax
+         output = vals - max_val - log_sum_exp
+
+         # Store results
+         tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+     """
+     Compute log_softmax using Triton kernel.
+
+     Args:
+         input: Input tensor
+         dim: Dimension along which to compute log_softmax (only -1 or last dim supported)
+
+     Returns:
+         Tensor with log_softmax applied along the specified dimension
+     """
+     if dim != -1 and dim != input.ndim - 1:
+         raise ValueError(
+             "This implementation only supports log_softmax along the last dimension"
+         )
+
+     # Flatten all dimensions except the last one
+     original_shape = input.shape
+     input_2d = input.reshape(-1, input.shape[-1])
+     input_2d = input_2d.contiguous()
+
+     n_rows, n_cols = input_2d.shape
+
+     # Allocate output tensor
+     output = torch.empty_like(input_2d)
+
+     # Choose block size based on the number of columns
+     BLOCK_SIZE = 1024
+
+     # Launch kernel with one block per row
+     grid = (n_rows,)
+     _log_softmax_kernel[grid](
+         input_2d,
+         output,
+         input_2d.stride(0),
+         output.stride(0),
+         n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+     )
+     # Reshape output back to original shape
+     return output.reshape(original_shape)
+
+
+ @triton.jit
+ def mean_kernel(
+     input_ptr,
+     output_ptr,
+     input_stride0,
+     input_stride1,
+     input_stride2,
+     output_stride0,
+     output_stride1,
+     M,  # size before reduction dim
+     N,  # size of reduction dim
+     K,  # size after reduction dim
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     Kernel for computing mean along a single dimension.
+     Input is viewed as (M, N, K) where N is the dimension being reduced.
+     """
+     # Program ID gives us which output element we're computing
+     pid = tl.program_id(0)
+
+     # Compute output indices
+     m_idx = pid // K
+     k_idx = pid % K
+
+     # Bounds check
+     if m_idx >= M or k_idx >= K:
+         return
+
+     # Accumulate sum across reduction dimension
+     acc = 0.0
+     for n_start in range(0, N, BLOCK_SIZE):
+         n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
+         mask = n_offsets < N
+
+         # Calculate input indices
+         input_idx = (
+             m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
+         )
+
+         # Load and accumulate
+         vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
+         acc += tl.sum(vals)
+
+     # Compute mean and store
+     mean_val = acc / N
+     output_idx = m_idx * output_stride0 + k_idx * output_stride1
+     tl.store(output_ptr + output_idx, mean_val)
+
+
+ def mean_dim(
+     input: torch.Tensor,
+     dim: int,
+     keepdim: bool = False,
+     dtype: torch.dtype | None = None,
+ ) -> torch.Tensor:
+     """
+     Triton implementation of torch.mean with single dimension reduction.
+
+     Args:
+         input: Input tensor
+         dim: Single dimension along which to compute mean
+         keepdim: Whether to keep the reduced dimension
+         dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
+
+     Returns:
+         Tensor with mean values along specified dimension
+     """
+     # Validate inputs
+     assert input.is_cuda, "Input must be a CUDA tensor"
+     assert (
+         -input.ndim <= dim < input.ndim
+     ), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
+
+     # Handle negative dim
+     if dim < 0:
+         dim = dim + input.ndim
+
+     # Handle dtype
+     if dtype is None:
+         if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
+             dtype = torch.float32
+         else:
+             dtype = input.dtype
+
+     # Convert input to appropriate dtype if needed
+     if input.dtype != dtype:
+         input = input.to(dtype)
+
+     # Get input shape and strides
+     shape = list(input.shape)
+
+     # Calculate dimensions for kernel
+     M = 1
+     for i in range(dim):
+         M *= shape[i]
+
+     N = shape[dim]
+
+     K = 1
+     for i in range(dim + 1, len(shape)):
+         K *= shape[i]
+
+     # Reshape input to 3D view (M, N, K)
+     input_3d = input.reshape(M, N, K)
+
+     # Create output shape
+     if keepdim:
+         output_shape = shape.copy()
+         output_shape[dim] = 1
+     else:
+         output_shape = shape[:dim] + shape[dim + 1 :]
+
+     # Create output tensor
+     output = torch.empty(output_shape, dtype=dtype, device=input.device)
+
+     # Reshape output for kernel
+     if keepdim:
+         output_2d = output.reshape(M, 1, K).squeeze(1)
+     else:
+         output_2d = output.reshape(M, K)
+
+     # Launch kernel
+     grid = (M * K,)
+     BLOCK_SIZE = 1024
+
+     mean_kernel[grid](
+         input_3d,
+         output_2d,
+         input_3d.stride(0),
+         input_3d.stride(1),
+         input_3d.stride(2),
+         output_2d.stride(0),
+         output_2d.stride(1) if output_2d.ndim > 1 else 0,
+         M,
+         N,
+         K,
+         BLOCK_SIZE,
+     )
+
+     return output
+
+
+ def mm_batch_invariant(a, b):
+     return matmul_persistent(a, b)
+
+
+ def addmm_batch_invariant(bias, a, b):
+     return matmul_persistent(a, b, bias=bias)
+
+
+ def _log_softmax_batch_invariant(input, dim, _half_to_float):
+     assert not _half_to_float, "not implemented"
+     return log_softmax(input, dim=dim)
+
+
+ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
+     assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
+     if len(dim) == 1:
+         return mean_dim(input, dim[0], keepdim=keepdim)
+     else:
+         assert input.dtype in {
+             torch.float16,
+             torch.bfloat16,
+             torch.float32,
+         }, "only float types supported for now"
+         n_elems = 1
+         for d in dim:
+             n_elems *= input.shape[d]
+         return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
+
+
+ _batch_invariant_MODE = False
+ _batch_invariant_LIB = None
+
+
+ def is_batch_invariant_mode_enabled():
+     return _batch_invariant_MODE
+
+
+ def enable_batch_invariant_mode():
+     global _batch_invariant_MODE, _batch_invariant_LIB
+     if _batch_invariant_MODE:
+         return
+
+     _batch_invariant_MODE = True
+     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
+     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+     _batch_invariant_LIB.impl(
+         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
+     )
+     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
+
+
+ def disable_batch_invariant_mode():
+     global _batch_invariant_MODE, _batch_invariant_LIB
+     if _batch_invariant_LIB is not None:
+         _batch_invariant_LIB._destroy()
+     _batch_invariant_MODE = False
+     _batch_invariant_LIB = None
+
+
+ @contextlib.contextmanager
+ def set_batch_invariant_mode(enabled: bool = True):
+     global _batch_invariant_MODE, _batch_invariant_LIB
+     old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
+     if enabled:
+         enable_batch_invariant_mode()
+     else:
+         disable_batch_invariant_mode()
+     yield
+     if _batch_invariant_LIB is not None:
+         _batch_invariant_LIB._destroy()
+     _batch_invariant_MODE, _batch_invariant_LIB = old_data
+
+
+ AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
+
+
+ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
+     return AttentionBlockSize(block_m=16, block_n=16)
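
The `set_batch_invariant_mode` context manager above registers the Triton kernels as CUDA overrides for `aten::mm`, `aten::addmm`, `aten::_log_softmax`, and `aten::mean.dim`, so a given row's result no longer depends on how many other rows share the batch (the kernels use fixed block sizes and a fixed reduction order instead of batch-dependent autotuned configs). A minimal usage sketch, not part of the diff, assuming a CUDA build and that the package `__init__` re-exports the names listed in `__all__`:

import torch
from sglang.srt.batch_invariant_ops import set_batch_invariant_mode

a = torch.randn(32, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

with set_batch_invariant_mode():
    full = torch.mm(a, b)        # 32-row batch
    single = torch.mm(a[:1], b)  # 1-row batch
    # With the persistent matmul, row 0 is computed identically in both
    # launches, so the results should match bitwise; with the default
    # cuBLAS path they typically do not.
    assert torch.equal(full[:1], single)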
sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
@@ -0,0 +1,142 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """
+ Checkpoint-engine integration for SGLang.
+ This module provides weight update functionality via IPC for checkpoint-engine compatibility.
+ """
+ import logging
+ from typing import Callable, Dict, Optional
+
+ import torch
+ import zmq
+
+ try:
+     from checkpoint_engine.worker import update_weights_from_ipc
+ except ImportError:
+     raise ImportError(
+         "checkpoint-engine is not installed. "
+         "Please install it with: pip install sglang[checkpoint-engine]"
+     )
+
+ logger = logging.getLogger(__name__)
+
+
+ class SGLangCheckpointEngineWorkerExtension:
+     """
+     Worker extension for SGLang to support checkpoint-engine IPC weight updates.
+     This class provides the interface needed for checkpoint-engine integration.
+     """
+
+     def __init__(self):
+         self._zmq_ctx: Optional[zmq.Context] = None
+
+     def get_device_uuid(self) -> str:
+         """Get the UUID of current device."""
+         # We need to implement this to get the device UUID
+         # This will be overridden when integrated into SGLang's worker
+         raise NotImplementedError(
+             "This method should be overridden by SGLang integration"
+         )
+
+     def get_device_id(self) -> int:
+         """Get the device ID."""
+         raise NotImplementedError(
+             "This method should be overridden by SGLang integration"
+         )
+
+     def get_model_loader(self) -> Callable:
+         """Get the model weight loader function."""
+         raise NotImplementedError(
+             "This method should be overridden by SGLang integration"
+         )
+
+     def get_post_hook(self) -> Optional[Callable]:
+         """Get the post-processing hook after weight loading."""
+         return None
+
+     def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
+         """
+         Update weights from IPC communication.
+         Args:
+             zmq_handles: Dict mapping device UUID to ZMQ socket path
+         """
+         if self._zmq_ctx is None:
+             self._zmq_ctx = zmq.Context()
+         device_uuid = self.get_device_uuid()
+         device_id = self.get_device_id()
+         if device_uuid not in zmq_handles:
+             raise ValueError(
+                 f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
+             )
+         update_weights_from_ipc(
+             self._zmq_ctx,
+             zmq_handles[device_uuid],
+             device_id=device_id,
+             run=self.get_model_loader(),
+             post_hook=self.get_post_hook(),
+         )
+
+
+ class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
+     """
+     Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
+     This class provides the concrete implementation for checkpoint-engine IPC weight updates.
+     """
+
+     def __init__(self, model_runner):
+         super().__init__()
+         self.model_runner = model_runner
+
+     def get_device_uuid(self) -> str:
+         """Get the UUID of current device."""
+         # Get device UUID for current device
+         device_id = torch.cuda.current_device()
+         try:
+             return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
+         except AssertionError as e:
+             raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
+
+     def get_device_id(self) -> int:
+         """Get the device ID."""
+         return torch.cuda.current_device()
+
+     def get_model_loader(self) -> Callable:
+         """Get the model weight loader function."""
+         return self.model_runner.model.load_weights
+
+     def get_post_hook(self) -> Optional[Callable]:
+         """Get the post-processing hook after weight loading."""
+
+         def post_hook():
+             # Perform post-processing after weight loading similar to DefaultModelLoader
+             try:
+                 from sglang.srt.model_loader.loader import device_loading_context
+
+                 # Process quantization methods after loading weights
+                 for _, module in self.model_runner.model.named_modules():
+                     quant_method = getattr(module, "quant_method", None)
+                     if quant_method is not None:
+                         # Move parameters to device if needed for quantization processing
+                         target_device = torch.device(
+                             "cuda", torch.cuda.current_device()
+                         )
+                         with device_loading_context(module, target_device):
+                             quant_method.process_weights_after_loading(module)
+                 # Call model-specific post-loading hook if available
+                 if hasattr(self.model_runner.model, "post_load_weights"):
+                     self.model_runner.model.post_load_weights()
+             except Exception as e:
+                 logger.warning(f"Post-hook processing failed: {e}")
+
+         return post_hook
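
The extension follows checkpoint-engine's worker protocol: the engine publishes updated weights over per-device ZMQ IPC sockets keyed by GPU UUID, and `update_weights_from_ipc` plugs the model's `load_weights` in as the receive callback, with the quantization post-hook run afterwards. A hypothetical wiring sketch, not part of the diff; `model_runner` and `zmq_handles` below are placeholders for an initialized SGLang ModelRunner and the handle map published by the checkpoint-engine side:

from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
    SGLangCheckpointEngineWorkerExtensionImpl,
)

# model_runner: an already-constructed ModelRunner holding the target model.
worker = SGLangCheckpointEngineWorkerExtensionImpl(model_runner)

# zmq_handles: Dict[str, str] mapping "GPU-<uuid>" to a ZMQ socket path,
# as produced by the checkpoint-engine publisher for each device.
worker.update_weights_from_ipc(zmq_handles)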