sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0

sglang/srt/models/falcon_h1.py
@@ -0,0 +1,578 @@
+import enum
+import logging
+from typing import Any, Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.falcon_h1 import FalconH1Config
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
+
+
+class FalconH1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        layer_id: int,
+        mlp_multipliers: List[float],
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+        self.layer_id = layer_id
+
+        self.intermediate_size = intermediate_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        x = x * self.down_multiplier
+        return x
+
+
+class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % self.attn_tp_size == 0
+        self.num_heads = self.total_num_heads // self.attn_tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= self.attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.layer_id = layer_id
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            rope_scaling=self.rope_scaling,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        self.d_ssm = (
+            int(config.mamba_expand * config.hidden_size)
+            if config.mamba_d_ssm is None
+            else config.mamba_d_ssm
+        )
+
+        self.mamba = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.mamba_conv_bias,
+            use_bias=config.mamba_proj_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.hidden_act,
+            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+        )
+
+        # FalconH1 all layers are sparse and have no nextn now
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        self.feed_forward = FalconH1MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            layer_id=layer_id,
+            mlp_multipliers=config.mlp_multipliers,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.pre_ff_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+        self.alt_stream = alt_stream
+        self.key_multiplier = config.key_multiplier
+
+        self.ssm_out_multiplier = config.ssm_out_multiplier
+        self.ssm_in_multiplier = config.ssm_in_multiplier
+
+        self.attention_in_multiplier = config.attention_in_multiplier
+        self.attn_out_multiplier = config.attention_out_multiplier
+
+        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+        self.zxbcdt_multipliers = config.ssm_multipliers
+        self._init_mup_vector()
+
+    def _init_mup_vector(self):
+        """
+        Non learnable per-block scaling vector composed of element-wise
+        multipliers applied to each separate contiguous block of the output
+        of the linear projection (in_proj) before further processing
+        (gating, convolution, SSM):
+
+        - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+        - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+        - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+        - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+          → zxbcdt_multipliers[3]
+        - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+        where:
+        - d_ssm: Dimension of state-space model latent
+        - G: Number of groups (n_groups)
+        - S: SSM state size per group
+        - All indices are divided by tp_size to support tensor parallelism
+        """
+        vector_shape = (
+            2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+        ) // self.tp_size
+        mup_vector = torch.ones(1, vector_shape)
+        # Z vector 0 -> d_ssm
+        mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+        # X vector d_ssm -> 2 * d_ssm
+        mup_vector[
+            :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+        ] *= self.zxbcdt_multipliers[1]
+        # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm)
+            // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[2]
+        # C vector 2 * d_ssm + (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[3]
+        # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+        mup_vector[
+            :,
+            (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+        ] *= self.zxbcdt_multipliers[4]
+
+        self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        k = k * self.key_multiplier
+        q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ):
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            # Attention block
+            attention_hidden_states = self.self_attention(
+                positions=positions,
+                hidden_states=hidden_states * self.attention_in_multiplier,
+                forward_batch=forward_batch,
+            )
+            attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+            attn_backend = forward_batch.attn_backend
+            assert isinstance(attn_backend, HybridLinearAttnBackend)
+            assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+            # Mamba block
+            mamba_hidden_states = torch.empty_like(hidden_states)
+            attn_backend.linear_attn_backend.forward(
+                self.mamba,
+                hidden_states * self.ssm_in_multiplier,
+                mamba_hidden_states,
+                layer_id=self.layer_id,
+                mup_vector=self.mup_vector,
+            )
+            mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+            hidden_states = attention_hidden_states + mamba_hidden_states
+
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+}
+
+
+class FalconH1Model(nn.Module):
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            enable_tp=not is_dp_attention_enabled(),
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+            return layer_class(
+                config,
+                idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            )
+
+        self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.infer_count = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        # mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds * self.embedding_multiplier
+        else:
+            hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                layer_id=i,
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                hidden_states = self.final_layernorm(hidden_states)
+            else:
+                hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+        return hidden_states
+
+
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class FalconH1ForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+        self.quant_config = quant_config
+        self.model = FalconH1Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            )
+        self.lm_head = self.lm_head.float()
+        self.lm_head_multiplier = config.lm_head_multiplier
+        self.logits_processor = LogitsProcessor(
+            config, logit_scale=self.lm_head_multiplier
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader")
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = FalconH1ForCausalLM
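
The `_init_mup_vector` docstring in the new FalconH1 decoder layer above describes how the five `ssm_multipliers` scale contiguous blocks (Z, X, B, C, dt) of the `in_proj` output. The standalone sketch below is not part of the diff; it mirrors that layout with made-up sizes (`d_ssm`, `n_groups`, `d_state`, `n_heads` are illustrative values, not FalconH1 defaults) and `tp_size` fixed at 1.

# Illustrative sketch of the block-wise mup vector layout; all numbers are hypothetical.
import torch

d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 6  # hypothetical sizes
multipliers = [1.1, 0.9, 1.2, 0.8, 1.5]         # hypothetical zxbcdt_multipliers
gts = n_groups * d_state                        # groups_time_state_size

mup = torch.ones(1, 2 * d_ssm + 2 * gts + n_heads)
mup[:, :d_ssm] *= multipliers[0]                                 # Z block
mup[:, d_ssm : 2 * d_ssm] *= multipliers[1]                      # X block
mup[:, 2 * d_ssm : 2 * d_ssm + gts] *= multipliers[2]            # B block
mup[:, 2 * d_ssm + gts : 2 * d_ssm + 2 * gts] *= multipliers[3]  # C block
mup[:, 2 * d_ssm + 2 * gts :] *= multipliers[4]                  # dt block
print(mup)  # each contiguous block carries its own multiplier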

sglang/srt/models/gemma3_causal.py
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
-    AutoModel,
     Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):
 
 
 EntryClass = Gemma3ForCausalLM
-AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)

sglang/srt/models/gemma3_mm.py
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
 
 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
@@ -23,7 +24,6 @@ import torch
 from torch import nn
 from transformers import Gemma3Config, PreTrainedModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +44,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
 from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
 
     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config
 
+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
         return hs
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()
 
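
The `lora_pattern` and `should_apply_lora` added to Gemma3ForConditionalGeneration above restrict LoRA to the language-model layers while skipping `vision_tower` and `multi_modal_projector`. A quick standalone check, not part of the diff and using hypothetical module names, shows which names the pattern accepts:

import re

lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

for name in [
    "language_model.model.layers.3.self_attn.qkv_proj",             # LoRA applied
    "language_model.model.layers.7.mlp.gate_up_proj",               # LoRA applied
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",  # skipped
    "multi_modal_projector.mm_input_projection",                    # skipped
]:
    print(name, "->", bool(lora_pattern.match(name)))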

sglang/srt/models/gemma3n_mm.py
@@ -14,7 +14,6 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import AutoModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,6 +37,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
 from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (

sglang/srt/models/glm4_moe.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
 
 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             or self.config.architectures[0] != architecture
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True