sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/falcon_h1.py
@@ -0,0 +1,570 @@
+ import logging
+ from typing import Any, Iterable, List, Optional, Set, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.configs.falcon_h1 import FalconH1Config
+ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+     HybridLinearAttnBackend,
+     Mamba2AttnBackend,
+ )
+ from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     is_dp_attention_enabled,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.server_args import get_global_server_args
+ from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+ logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()
+
+
+ class FalconH1MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         layer_id: int,
+         mlp_multipliers: List[float],
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("gate_up_proj", prefix),
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("down_proj", prefix),
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+         self.layer_id = layer_id
+
+         self.intermediate_size = intermediate_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+     def forward(
+         self,
+         x,
+         forward_batch=None,
+         use_reduce_scatter: bool = False,
+     ):
+         gate_up, _ = self.gate_up_proj(x)
+         gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(
+             x,
+             skip_all_reduce=use_reduce_scatter,
+         )
+         x = x * self.down_multiplier
+         return x
+
+
+ class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: FalconH1Config,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % self.attn_tp_size == 0
+         self.num_heads = self.total_num_heads // self.attn_tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= self.attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+         self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+         self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+         self.layer_id = layer_id
+
+         self.rotary_emb = get_rope(
+             head_size=self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             rope_scaling=self.rope_scaling,
+             base=self.rope_theta,
+             partial_rotary_factor=self.partial_rotary_factor,
+             is_neox_style=True,
+             dtype=torch.get_default_dtype(),  # see impl of get_rope
+         )
+
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=f"{prefix}.attn",
+         )
+
+         self.d_ssm = (
+             int(config.mamba_expand * config.hidden_size)
+             if config.mamba_d_ssm is None
+             else config.mamba_d_ssm
+         )
+
+         self.mamba = MambaMixer2(
+             cache_params=config.mamba2_cache_params,
+             hidden_size=config.hidden_size,
+             use_conv_bias=config.mamba_conv_bias,
+             use_bias=config.mamba_proj_bias,
+             n_groups=config.mamba_n_groups,
+             rms_norm_eps=config.rms_norm_eps,
+             activation=config.hidden_act,
+             use_rms_norm=config.mamba_rms_norm,
+             prefix=f"{prefix}.mixer",
+         )
+
+         # FalconH1 all layers are sparse and have no nextn now
+         self.is_layer_sparse = False
+         is_previous_layer_sparse = False
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         self.feed_forward = FalconH1MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             layer_id=layer_id,
+             mlp_multipliers=config.mlp_multipliers,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.pre_ff_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+         self.alt_stream = alt_stream
+         self.key_multiplier = config.key_multiplier
+
+         self.ssm_out_multiplier = config.ssm_out_multiplier
+         self.ssm_in_multiplier = config.ssm_in_multiplier
+
+         self.attention_in_multiplier = config.attention_in_multiplier
+         self.attn_out_multiplier = config.attention_out_multiplier
+
+         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+         self.zxbcdt_multipliers = config.ssm_multipliers
+         self._init_mup_vector()
+
+     def _init_mup_vector(self):
+         """
+         Non learnable per-block scaling vector composed of element-wise
+         multipliers applied to each separate contiguous block of the output
+         of the linear projection (in_proj) before further processing
+         (gating, convolution, SSM):
+
+         - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+         - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+         - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+         - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+           → zxbcdt_multipliers[3]
+         - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+         where:
+         - d_ssm: Dimension of state-space model latent
+         - G: Number of groups (n_groups)
+         - S: SSM state size per group
+         - All indices are divided by tp_size to support tensor parallelism
+         """
+         vector_shape = (
+             2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+         ) // self.tp_size
+         mup_vector = torch.ones(1, vector_shape)
+         # Z vector 0 -> d_ssm
+         mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+         # X vector d_ssm -> 2 * d_ssm
+         mup_vector[
+             :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+         ] *= self.zxbcdt_multipliers[1]
+         # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+         mup_vector[
+             :,
+             (2 * self.d_ssm)
+             // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+             // self.tp_size,
+         ] *= self.zxbcdt_multipliers[2]
+         # C vector 2 * d_ssm + (n_group * d_state)
+         # -> 2 * d_ssm + 2 * (n_group * d_state)
+         mup_vector[
+             :,
+             (2 * self.d_ssm + self.groups_time_state_size)
+             // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+             // self.tp_size,
+         ] *= self.zxbcdt_multipliers[3]
+         # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+         # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+         mup_vector[
+             :,
+             (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+         ] *= self.zxbcdt_multipliers[4]
+
+         self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+     def self_attention(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         k = k * self.key_multiplier
+         q, k = self.rotary_emb(positions, q, k)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+         **kwargs: Any,
+     ):
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             # Attention block
+             attention_hidden_states = self.self_attention(
+                 positions=positions,
+                 hidden_states=hidden_states * self.attention_in_multiplier,
+                 forward_batch=forward_batch,
+             )
+             attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+             attn_backend = forward_batch.attn_backend
+             assert isinstance(attn_backend, HybridLinearAttnBackend)
+             assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+             # Mamba block
+             mamba_hidden_states = torch.empty_like(hidden_states)
+             attn_backend.linear_attn_backend.forward(
+                 self.mamba,
+                 hidden_states * self.ssm_in_multiplier,
+                 mamba_hidden_states,
+                 layer_id=self.layer_id,
+                 mup_vector=self.mup_vector,
+             )
+             mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+             hidden_states = attention_hidden_states + mamba_hidden_states
+
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.feed_forward(
+             hidden_states, forward_batch, use_reduce_scatter
+         )
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ ALL_DECODER_LAYER_TYPES = {
+     "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+ }
+
+
+ class FalconH1Model(nn.Module):
+     def __init__(
+         self,
+         config: FalconH1Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         alt_stream = torch.cuda.Stream() if _is_cuda else None
+         self.embedding_multiplier = config.embedding_multiplier
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             enable_tp=not is_dp_attention_enabled(),
+         )
+
+         def get_layer(idx: int, prefix: str):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+             return layer_class(
+                 config,
+                 idx,
+                 quant_config=quant_config,
+                 prefix=prefix,
+                 alt_stream=alt_stream,
+             )
+
+         self.layers = make_layers(
+             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+         )
+
+         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.infer_count = 0
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         # mamba_cache_params: MambaCacheParams,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+
+         # pass a sequence index tensor, that is required for
+         # proper continuous batching computation including
+         # chunked prefill
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds * self.embedding_multiplier
+         else:
+             hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 layer_id=i,
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 forward_batch=forward_batch,
+             )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is None:
+                 hidden_states = self.final_layernorm(hidden_states)
+             else:
+                 hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+         return hidden_states
+
+
+ class FalconH1ForCausalLM(nn.Module):
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: FalconH1Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.pp_group = get_pp_group()
+         assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+         self.quant_config = quant_config
+         self.model = FalconH1Model(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 org_num_embeddings=config.vocab_size,
+                 prefix=add_prefix("lm_head", prefix),
+                 use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
+             )
+         self.lm_head = self.lm_head.float()
+         self.lm_head_multiplier = config.lm_head_multiplier
+         self.logits_processor = LogitsProcessor(
+             config, logit_scale=self.lm_head_multiplier
+         )
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+     ) -> Set[str]:
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             if ".self_attn." in name:
+                 name = name.replace(".self_attn", "")
+
+             if "A_log" in name:
+                 name = name.replace("A_log", "A")
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip layers on other devices.
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+                 if name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader")
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                 weight_loader(param, loaded_weight)
+
+             loaded_params.add(name)
+         return loaded_params
+
+
+ EntryClass = FalconH1ForCausalLM
sglang/srt/models/gemma3_causal.py
@@ -20,7 +20,6 @@ import torch.nn.functional as F
  from torch import nn
  from transformers import (
      ROPE_INIT_FUNCTIONS,
-     AutoModel,
      Gemma3TextConfig,
      PretrainedConfig,
      PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):


  EntryClass = Gemma3ForCausalLM
- AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
sglang/srt/models/gemma3_mm.py
@@ -16,6 +16,7 @@
  # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

  import logging
+ import re
  from functools import lru_cache
  from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

@@ -23,7 +24,6 @@ import torch
  from torch import nn
  from transformers import Gemma3Config, PreTrainedModel

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.layernorm import Gemma3RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +44,7 @@ from sglang.srt.model_loader.weight_utils import (
  from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
  from sglang.srt.models.siglip import SiglipVisionModel
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
      embedding_modules = {}
      embedding_padding_modules = []
      supports_lora = True
+     # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+     lora_pattern = re.compile(
+         r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+     )

      def __init__(
          self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
          self.config = config
          self.quant_config = quant_config

+         # For LoRA compatibility: expose text_config attributes at top level
+         # This allows LoRA code to work without special multimodal handling
+         if not hasattr(config, "num_hidden_layers"):
+             config.num_hidden_layers = config.text_config.num_hidden_layers
+         if not hasattr(config, "hidden_size"):
+             config.hidden_size = config.text_config.hidden_size
+
          self.vision_tower = SiglipVisionModel(
              config=config.vision_config,
              quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):

          return hs

+     def should_apply_lora(self, module_name: str) -> bool:
+         """Skip vision tower and multi_modal_projector for LoRA."""
+         return bool(self.lora_pattern.match(module_name))
+
      def tie_weights(self):
          return self.language_model.tie_weights()

sglang/srt/models/gemma3n_mm.py
@@ -14,9 +14,7 @@ from transformers import (
  )
  from transformers.models.auto.modeling_auto import AutoModel

- from sglang.srt.hf_transformers_utils import get_processor
- from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+ from sglang.srt.layers.linear import RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -38,6 +36,7 @@ from sglang.srt.model_loader.weight_utils import (
  from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
  from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)
