sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/falcon_h1.py
@@ -0,0 +1,576 @@
+ import enum
+ import logging
+ from typing import Any, Iterable, List, Optional, Set, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.configs.falcon_h1 import FalconH1Config
+ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     is_dp_attention_enabled,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+ logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()
+
+
+ class FalconH1MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         layer_id: int,
+         mlp_multipliers: List[float],
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("gate_up_proj", prefix),
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("down_proj", prefix),
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+         self.layer_id = layer_id
+
+         self.intermediate_size = intermediate_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+     def forward(
+         self,
+         x,
+         forward_batch=None,
+         use_reduce_scatter: bool = False,
+     ):
+         gate_up, _ = self.gate_up_proj(x)
+         gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(
+             x,
+             skip_all_reduce=use_reduce_scatter,
+         )
+         x = x * self.down_multiplier
+         return x
+
+
+ class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: FalconH1Config,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % self.attn_tp_size == 0
+         self.num_heads = self.total_num_heads // self.attn_tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= self.attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+         self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+         self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+         self.layer_id = layer_id
+
+         self.rotary_emb = get_rope(
+             head_size=self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             rope_scaling=self.rope_scaling,
+             base=self.rope_theta,
+             partial_rotary_factor=self.partial_rotary_factor,
+             is_neox_style=True,
+             dtype=torch.get_default_dtype(),  # see impl of get_rope
+         )
+
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=f"{prefix}.attn",
+         )
+
+         self.d_ssm = (
+             int(config.mamba_expand * config.hidden_size)
+             if config.mamba_d_ssm is None
+             else config.mamba_d_ssm
+         )
+
+         self.mamba = MambaMixer2(
+             hidden_size=config.hidden_size,
+             ssm_state_size=config.mamba_d_state,
+             conv_kernel_size=config.mamba_d_conv,
+             intermediate_size=self.d_ssm,
+             use_conv_bias=config.mamba_conv_bias,
+             use_bias=config.mamba_proj_bias,
+             n_groups=config.mamba_n_groups,
+             num_heads=config.mamba_n_heads,
+             layer_id=layer_id,
+             head_dim=config.mamba_d_head,
+             rms_norm_eps=config.rms_norm_eps,
+             chunk_size=config.mamba_chunk_size,
+             activation=config.hidden_act,
+             use_rms_norm=config.mamba_rms_norm,
+             prefix=f"{prefix}.mixer",
+         )
+
+         # FalconH1 all layers are sparse and have no nextn now
+         self.is_layer_sparse = False
+         is_previous_layer_sparse = False
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         self.feed_forward = FalconH1MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             layer_id=layer_id,
+             mlp_multipliers=config.mlp_multipliers,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.pre_ff_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+         self.alt_stream = alt_stream
+         self.key_multiplier = config.key_multiplier
+
+         self.ssm_out_multiplier = config.ssm_out_multiplier
+         self.ssm_in_multiplier = config.ssm_in_multiplier
+
+         self.attention_in_multiplier = config.attention_in_multiplier
+         self.attn_out_multiplier = config.attention_out_multiplier
+
+         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+         self.zxbcdt_multipliers = config.ssm_multipliers
+         self._init_mup_vector()
+
+     def _init_mup_vector(self):
+         """
+         Non learnable per-block scaling vector composed of element-wise
+         multipliers applied to each separate contiguous block of the output
+         of the linear projection (in_proj) before further processing
+         (gating, convolution, SSM):
+
+         - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+         - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+         - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+         - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+           → zxbcdt_multipliers[3]
+         - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+         where:
+         - d_ssm: Dimension of state-space model latent
+         - G: Number of groups (n_groups)
+         - S: SSM state size per group
+         - All indices are divided by tp_size to support tensor parallelism
+         """
+         vector_shape = (
+             2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+         ) // self.tp_size
+         mup_vector = torch.ones(1, vector_shape)
+         # Z vector 0 -> d_ssm
+         mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+         # X vector d_ssm -> 2 * d_ssm
+         mup_vector[
+             :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+         ] *= self.zxbcdt_multipliers[1]
+         # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+         mup_vector[
+             :,
+             (2 * self.d_ssm)
+             // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+             // self.tp_size,
+         ] *= self.zxbcdt_multipliers[2]
+         # C vector 2 * d_ssm + (n_group * d_state)
+         # -> 2 * d_ssm + 2 * (n_group * d_state)
+         mup_vector[
+             :,
+             (2 * self.d_ssm + self.groups_time_state_size)
+             // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+             // self.tp_size,
+         ] *= self.zxbcdt_multipliers[3]
+         # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+         # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+         mup_vector[
+             :,
+             (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+         ] *= self.zxbcdt_multipliers[4]
+
+         self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+     def self_attention(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         k = k * self.key_multiplier
+         q, k = self.rotary_emb(positions, q, k)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+         **kwargs: Any,
+     ):
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             # Attention block
+             attention_hidden_states = self.self_attention(
+                 positions=positions,
+                 hidden_states=hidden_states * self.attention_in_multiplier,
+                 forward_batch=forward_batch,
+             )
+             attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+             # Mamba block
+             mamba_hidden_states = torch.empty_like(hidden_states)
+             self.mamba(
+                 hidden_states * self.ssm_in_multiplier,
+                 mamba_hidden_states,
+                 forward_batch=forward_batch,
+                 mup_vector=self.mup_vector,
+             )
+             mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+             hidden_states = attention_hidden_states + mamba_hidden_states
+
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.feed_forward(
+             hidden_states, forward_batch, use_reduce_scatter
+         )
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ ALL_DECODER_LAYER_TYPES = {
+     "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+ }
+
+
+ class FalconH1Model(nn.Module):
+     def __init__(
+         self,
+         config: FalconH1Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         alt_stream = torch.cuda.Stream() if _is_cuda else None
+         self.embedding_multiplier = config.embedding_multiplier
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             enable_tp=not is_dp_attention_enabled(),
+         )
+
+         def get_layer(idx: int, prefix: str):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+             return layer_class(
+                 config,
+                 idx,
+                 quant_config=quant_config,
+                 prefix=prefix,
+                 alt_stream=alt_stream,
+             )
+
+         self.layers = make_layers(
+             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+         )
+
+         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.infer_count = 0
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         # mamba_cache_params: MambaCacheParams,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+
+         # pass a sequence index tensor, that is required for
+         # proper continuous batching computation including
+         # chunked prefill
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds * self.embedding_multiplier
+         else:
+             hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 layer_id=i,
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 forward_batch=forward_batch,
+             )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is None:
+                 hidden_states = self.final_layernorm(hidden_states)
+             else:
+                 hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+         return hidden_states
+
+
+ class HybridLayerType(enum.Enum):
+     full_attention = "attention"
+     swa_attention = "swa_attention"
+     linear_attention = "linear_attention"
+     mamba2 = "mamba"
+
+
+ class FalconH1ForCausalLM(nn.Module):
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: FalconH1Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.pp_group = get_pp_group()
+         assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+         self.quant_config = quant_config
+         self.model = FalconH1Model(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 org_num_embeddings=config.vocab_size,
+                 prefix=add_prefix("lm_head", prefix),
+                 use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             )
+         self.lm_head = self.lm_head.float()
+         self.lm_head_multiplier = config.lm_head_multiplier
+         self.logits_processor = LogitsProcessor(
+             config, logit_scale=self.lm_head_multiplier
+         )
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+     ) -> Set[str]:
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             if ".self_attn." in name:
+                 name = name.replace(".self_attn", "")
+
+             if "A_log" in name:
+                 name = name.replace("A_log", "A")
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip layers on other devices.
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+                 if name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader")
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                 weight_loader(param, loaded_weight)
+
+             loaded_params.add(name)
+         return loaded_params
+
+
+ EntryClass = FalconH1ForCausalLM
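
Editor's note: the `_init_mup_vector` docstring above describes how the five muP multipliers scale contiguous blocks (Z, X, B, C, dt) of the in_proj output. The standalone sketch below is not part of the package; it only illustrates that block layout with made-up toy sizes and hypothetical multiplier values, using tp_size = 1 so the `// tp_size` division is a no-op.

# Illustrative sketch of the muP block layout described in _init_mup_vector.
# All sizes and multiplier values are toy assumptions, not FalconH1 defaults.
import torch

d_ssm, n_groups, d_state, n_heads, tp_size = 8, 2, 4, 4, 1
G_S = n_groups * d_state  # corresponds to groups_time_state_size in the layer

vector_shape = (2 * d_ssm + 2 * G_S + n_heads) // tp_size  # 36 in this toy case
mup_vector = torch.ones(1, vector_shape)

zxbcdt_multipliers = [0.5, 1.5, 2.0, 3.0, 0.25]  # hypothetical values
# [Z | X | B | C | dt] blocks, matching the slices in _init_mup_vector
mup_vector[:, :d_ssm] *= zxbcdt_multipliers[0]
mup_vector[:, d_ssm : 2 * d_ssm] *= zxbcdt_multipliers[1]
mup_vector[:, 2 * d_ssm : 2 * d_ssm + G_S] *= zxbcdt_multipliers[2]
mup_vector[:, 2 * d_ssm + G_S : 2 * d_ssm + 2 * G_S] *= zxbcdt_multipliers[3]
mup_vector[:, 2 * d_ssm + 2 * G_S :] *= zxbcdt_multipliers[4]
# The layer multiplies in_proj's output elementwise by this vector before
# gating, convolution, and the SSM scan.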
sglang/srt/models/gemma3_causal.py
@@ -20,7 +20,6 @@ import torch.nn.functional as F
  from torch import nn
  from transformers import (
      ROPE_INIT_FUNCTIONS,
-     AutoModel,
      Gemma3TextConfig,
      PretrainedConfig,
      PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):


  EntryClass = Gemma3ForCausalLM
- AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
sglang/srt/models/gemma3_mm.py
@@ -23,7 +23,6 @@ import torch
  from torch import nn
  from transformers import Gemma3Config, PreTrainedModel

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.layernorm import Gemma3RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +43,7 @@ from sglang.srt.model_loader.weight_utils import (
  from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
  from sglang.srt.models.siglip import SiglipVisionModel
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

sglang/srt/models/gemma3n_mm.py
@@ -14,7 +14,6 @@ from transformers import (
  )
  from transformers.models.auto.modeling_auto import AutoModel

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,6 +37,7 @@ from sglang.srt.model_loader.weight_utils import (
  from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
  from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
      def should_apply_lora(self, module_name: str) -> bool:
          return bool(self.lora_pattern.match(module_name))

-     def get_hidden_dim(self, module_name):
+     def get_hidden_dim(self, module_name, layer_idx):
          # return input_dim, output_dim
          if module_name == "qkv_proj":
              return (
sglang/srt/models/glm4_moe.py
@@ -12,7 +12,7 @@
  # limitations under the License.
  # ==============================================================================

- """Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+ """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

  import logging
  from typing import Any, Dict, Iterable, Optional, Tuple
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
              routed_scaling_factor=self.routed_scaling_factor,
          )

-         self.experts = get_moe_impl_class()(
+         self.experts = get_moe_impl_class(quant_config)(
              num_experts=config.n_routed_experts
              + self.num_fused_shared_experts
              + global_server_args_dict["ep_num_redundant_experts"],
@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
              or self.config.architectures[0] != architecture
              or self.config.n_shared_experts != 1
          ):
-             disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+             disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
          elif get_moe_expert_parallel_world_size() > 1:
-             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+             disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."

          if disable_reason is not None:
              global_server_args_dict["disable_shared_experts_fusion"] = True
sglang/srt/models/glm4_moe_nextn.py
@@ -12,7 +12,7 @@
  # limitations under the License.
  # ==============================================================================

- """Inference-only GLM-4.5 NextN Speculative Decoding."""
+ """Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
  import logging
  from typing import Iterable, Optional, Tuple

@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
          super().__init__()
          if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
              logger.warning(
-                 "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+                 "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
              )
              quant_config = None