sglang-0.5.2rc1-py3-none-any.whl → sglang-0.5.3-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
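
The hunks below show a subset of these files in detail, starting with the new FalconH1 model. As a quick sanity check after switching wheels, the installed version can be read from the package itself; a minimal sketch, assuming the 0.5.3 wheel has been installed into the current environment:

# Assumes `pip install "sglang==0.5.3"` has already been run in this environment.
import sglang

print(sglang.__version__)  # expected to print "0.5.3" after the upgrade
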
sglang/srt/models/falcon_h1.py (new file, +576 -0)
@@ -0,0 +1,576 @@
+import enum
+import logging
+from typing import Any, Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.falcon_h1 import FalconH1Config
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
+
+
+class FalconH1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        layer_id: int,
+        mlp_multipliers: List[float],
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+        self.layer_id = layer_id
+
+        self.intermediate_size = intermediate_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        x = x * self.down_multiplier
+        return x
+
+
+class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % self.attn_tp_size == 0
+        self.num_heads = self.total_num_heads // self.attn_tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= self.attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.layer_id = layer_id
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            rope_scaling=self.rope_scaling,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        self.d_ssm = (
+            int(config.mamba_expand * config.hidden_size)
+            if config.mamba_d_ssm is None
+            else config.mamba_d_ssm
+        )
+
+        self.mamba = MambaMixer2(
+            hidden_size=config.hidden_size,
+            ssm_state_size=config.mamba_d_state,
+            conv_kernel_size=config.mamba_d_conv,
+            intermediate_size=self.d_ssm,
+            use_conv_bias=config.mamba_conv_bias,
+            use_bias=config.mamba_proj_bias,
+            n_groups=config.mamba_n_groups,
+            num_heads=config.mamba_n_heads,
+            layer_id=layer_id,
+            head_dim=config.mamba_d_head,
+            rms_norm_eps=config.rms_norm_eps,
+            chunk_size=config.mamba_chunk_size,
+            activation=config.hidden_act,
+            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+        )
+
+        # FalconH1 all layers are sparse and have no nextn now
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        self.feed_forward = FalconH1MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            layer_id=layer_id,
+            mlp_multipliers=config.mlp_multipliers,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.pre_ff_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+        self.alt_stream = alt_stream
+        self.key_multiplier = config.key_multiplier
+
+        self.ssm_out_multiplier = config.ssm_out_multiplier
+        self.ssm_in_multiplier = config.ssm_in_multiplier
+
+        self.attention_in_multiplier = config.attention_in_multiplier
+        self.attn_out_multiplier = config.attention_out_multiplier
+
+        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+        self.zxbcdt_multipliers = config.ssm_multipliers
+        self._init_mup_vector()
+
+    def _init_mup_vector(self):
+        """
+        Non learnable per-block scaling vector composed of element-wise
+        multipliers applied to each separate contiguous block of the output
+        of the linear projection (in_proj) before further processing
+        (gating, convolution, SSM):
+
+        - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+        - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+        - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+        - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+          → zxbcdt_multipliers[3]
+        - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+        where:
+        - d_ssm: Dimension of state-space model latent
+        - G: Number of groups (n_groups)
+        - S: SSM state size per group
+        - All indices are divided by tp_size to support tensor parallelism
+        """
+        vector_shape = (
+            2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+        ) // self.tp_size
+        mup_vector = torch.ones(1, vector_shape)
+        # Z vector 0 -> d_ssm
+        mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+        # X vector d_ssm -> 2 * d_ssm
+        mup_vector[
+            :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+        ] *= self.zxbcdt_multipliers[1]
+        # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm)
+            // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[2]
+        # C vector 2 * d_ssm + (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[3]
+        # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+        mup_vector[
+            :,
+            (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+        ] *= self.zxbcdt_multipliers[4]
+
+        self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        k = k * self.key_multiplier
+        q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ):
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            # Attention block
+            attention_hidden_states = self.self_attention(
+                positions=positions,
+                hidden_states=hidden_states * self.attention_in_multiplier,
+                forward_batch=forward_batch,
+            )
+            attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+            # Mamba block
+            mamba_hidden_states = torch.empty_like(hidden_states)
+            self.mamba(
+                hidden_states * self.ssm_in_multiplier,
+                mamba_hidden_states,
+                forward_batch=forward_batch,
+                mup_vector=self.mup_vector,
+            )
+            mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+            hidden_states = attention_hidden_states + mamba_hidden_states
+
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+}
+
+
+class FalconH1Model(nn.Module):
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            enable_tp=not is_dp_attention_enabled(),
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+            return layer_class(
+                config,
+                idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            )
+
+        self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.infer_count = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        # mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds * self.embedding_multiplier
+        else:
+            hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                layer_id=i,
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                hidden_states = self.final_layernorm(hidden_states)
+            else:
+                hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+        return hidden_states
+
+
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class FalconH1ForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+        self.quant_config = quant_config
+        self.model = FalconH1Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            )
+        self.lm_head = self.lm_head.float()
+        self.lm_head_multiplier = config.lm_head_multiplier
+        self.logits_processor = LogitsProcessor(
+            config, logit_scale=self.lm_head_multiplier
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader")
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # if is_pp_missing_parameter(name, self):
+                #     continue

+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = FalconH1ForCausalLM
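
The `_init_mup_vector` helper above builds a non-learnable scaling vector whose contiguous Z / X / B / C / dt blocks are multiplied by the five `ssm_multipliers` entries before the Mamba mixer consumes the in_proj output. A minimal standalone sketch of that block layout, using toy sizes and tp_size = 1 (all names and numbers below are illustrative, not taken from the package):

import torch

# Toy dimensions: d_ssm = 4, n_groups * d_state = 6, n_heads = 2, tp_size = 1.
d_ssm, gs, n_heads = 4, 6, 2
multipliers = [1.5, 2.0, 0.5, 0.25, 3.0]  # Z, X, B, C, dt multipliers

mup = torch.ones(2 * d_ssm + 2 * gs + n_heads)
mup[:d_ssm] *= multipliers[0]                               # Z block
mup[d_ssm : 2 * d_ssm] *= multipliers[1]                    # X block
mup[2 * d_ssm : 2 * d_ssm + gs] *= multipliers[2]           # B block
mup[2 * d_ssm + gs : 2 * d_ssm + 2 * gs] *= multipliers[3]  # C block
mup[2 * d_ssm + 2 * gs :] *= multipliers[4]                 # dt block
print(mup)  # 22 entries, scaled block-by-block as in the docstring layout
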
sglang/srt/models/gemma3_causal.py (+0 -2)
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
-    AutoModel,
     Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):


 EntryClass = Gemma3ForCausalLM
-AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
sglang/srt/models/gemma3_mm.py (+1 -1)
@@ -23,7 +23,6 @@ import torch
 from torch import nn
 from transformers import Gemma3Config, PreTrainedModel

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +43,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
 from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

sglang/srt/models/gemma3n_mm.py (+2 -2)
@@ -14,7 +14,6 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import AutoModel

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,6 +37,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
 from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))

-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (
sglang/srt/models/glm4_moe.py (+14 -5)
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion=False,
+        gemm_output_zero_allocator: BumpAllocator = None,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

@@ -423,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
@@ -776,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             or self.config.architectures[0] != architecture
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
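
The GLM-4.5/GLM-4.6 hunks above follow one recurring pattern: a new optional `gemm_output_zero_allocator: BumpAllocator = None` parameter is threaded from the decoder layer down into the MoE and MLP forward methods, with a `None` default so existing call sites keep working. A minimal sketch of that pattern with stand-in classes (the names below are illustrative, not sglang's):

from typing import Optional

class ScratchAllocator:
    """Stand-in for an allocator object passed down the call stack."""
    pass

class InnerMLP:
    def forward(self, x, scratch: Optional[ScratchAllocator] = None):
        # Use the allocator only when one is provided; otherwise run unchanged.
        return x

class OuterLayer:
    def __init__(self):
        self.mlp = InnerMLP()

    def forward(self, x, scratch: Optional[ScratchAllocator] = None):
        # The optional argument is forwarded as-is, so old callers that never
        # pass it and new callers that do both keep working.
        return self.mlp.forward(x, scratch=scratch)

x = [1, 2, 3]
print(OuterLayer().forward(x))                      # legacy call site
print(OuterLayer().forward(x, ScratchAllocator()))  # new call site supplying an allocator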