sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next.py (new file)
@@ -0,0 +1,1069 @@
+import enum
+import logging
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.configs.qwen3_next import Qwen3NextConfig
+from sglang.srt.distributed import (
+    divide,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
+from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    sharded_weight_loader,
+)
+from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)
+
+logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
+_is_npu = is_npu()
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def fused_qkvzba_split_reshape_cat_kernel(
+    mixed_qkv,
+    z,
+    b,
+    a,
+    mixed_qkvz,
+    mixed_ba,
+    NUM_HEADS_QK: tl.constexpr,
+    NUM_HEADS_V: tl.constexpr,
+    HEAD_QK: tl.constexpr,
+    HEAD_V: tl.constexpr,
+):
+    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
+    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
+    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
+    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    q_end: tl.constexpr = HEAD_QK
+    blk_q_ptr = (
+        mixed_qkvz
+        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+        + i_qk * QKVZ_DIM_T
+        + tl.arange(0, q_end)
+    )
+    k_end: tl.constexpr = q_end + HEAD_QK
+    blk_k_ptr = (
+        mixed_qkvz
+        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+        + i_qk * QKVZ_DIM_T
+        + tl.arange(q_end, k_end)
+    )
+    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    blk_v_ptr = (
+        mixed_qkvz
+        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+        + i_qk * QKVZ_DIM_T
+        + tl.arange(k_end, v_end)
+    )
+    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    blk_z_ptr = (
+        mixed_qkvz
+        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+        + i_qk * QKVZ_DIM_T
+        + tl.arange(v_end, z_end)
+    )
+    blk_q_st_ptr = (
+        mixed_qkv
+        + i_bs * NUM_HEADS_QK * QKV_DIM_T
+        + i_qk * HEAD_QK
+        + tl.arange(0, HEAD_QK)
+    )
+    blk_k_st_ptr = (
+        mixed_qkv
+        + i_bs * NUM_HEADS_QK * QKV_DIM_T
+        + NUM_HEADS_QK * HEAD_QK
+        + i_qk * HEAD_QK
+        + tl.arange(0, HEAD_QK)
+    )
+    blk_v_st_ptr = (
+        mixed_qkv
+        + i_bs * NUM_HEADS_QK * QKV_DIM_T
+        + NUM_HEADS_QK * HEAD_QK * 2
+        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+    )
+    blk_z_st_ptr = (
+        z
+        + i_bs * NUM_HEADS_V * HEAD_V
+        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+    )
+    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
+    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
+    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
+    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
+    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
+    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
+    for i in tl.static_range(b_end):
+        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
+        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
+    for i in tl.static_range(b_end, a_end):
+        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+        blk_a_st_ptr = (
+            a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
+        )
+        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
+
+
+def fused_qkvzba_split_reshape_cat(
+    mixed_qkvz,
+    mixed_ba,
+    num_heads_qk,
+    num_heads_v,
+    head_qk,
+    head_v,
+):
+    batch, seq_len = mixed_qkvz.shape[0], 1
+    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
+    mixed_qkv = torch.empty(
+        [batch * seq_len, qkv_dim_t],
+        dtype=mixed_qkvz.dtype,
+        device=mixed_qkvz.device,
+    )
+    z = torch.empty(
+        [batch * seq_len, num_heads_v, head_v],
+        dtype=mixed_qkvz.dtype,
+        device=mixed_qkvz.device,
+    )
+    b = torch.empty(
+        [batch * seq_len, num_heads_v],
+        dtype=mixed_ba.dtype,
+        device=mixed_ba.device,
+    )
+    a = torch.empty_like(b)
+    grid = (batch * seq_len, num_heads_qk)
+    fused_qkvzba_split_reshape_cat_kernel[grid](
+        mixed_qkv,
+        z,
+        b,
+        a,
+        mixed_qkvz,
+        mixed_ba,
+        num_heads_qk,
+        num_heads_v,
+        head_qk,
+        head_v,
+        num_warps=1,
+        num_stages=3,
+    )
+    return mixed_qkv, z, b, a
+
+
+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+@triton.jit
+def fused_gdn_gating_kernel(
+    g,
+    A_log,
+    a,
+    dt_bias,
+    seq_len,
+    NUM_HEADS: tl.constexpr,
+    beta: tl.constexpr,
+    threshold: tl.constexpr,
+    BLK_HEADS: tl.constexpr,
+):
+    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
+    mask = head_off < NUM_HEADS
+    blk_A_log = tl.load(A_log + head_off, mask=mask)
+    blk_a = tl.load(a + off, mask=mask)
+    blk_bias = tl.load(dt_bias + head_off, mask=mask)
+    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+    softplus_x = tl.where(
+        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+    )
+    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+
+
+def fused_gdn_gating(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    dt_bias: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+) -> torch.Tensor:
+    batch, num_heads = a.shape
+    seq_len = 1
+    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
+    g = torch.empty_like(a, dtype=torch.float32)
+    fused_gdn_gating_kernel[grid](
+        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+    )
+    return g
+
+
+class Qwen3GatedDeltaNet(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        self.alt_stream = alt_stream
+
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_id = layer_id
+        self.activation = config.hidden_act
+        self.layer_norm_epsilon = config.rms_norm_eps
+
+        # QKV
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = ColumnParallelLinear(
+            input_size=self.conv_kernel_size,
+            output_size=self.conv_dim,
+            bias=False,
+            quant_config=None,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        # projection of the input hidden states
+        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+        projection_size_ba = self.num_v_heads * 2
+
+        self.in_proj_qkvz = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=projection_size_qkvz,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=projection_size_ba,
+            bias=False,
+            quant_config=None,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        query_key_settings = (self.key_dim, 0, False)
+        value_settings = (self.value_dim, 0, False)
+
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.weight,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        query_key_settings,
+                        query_key_settings,
+                        value_settings,
+                    ],
+                    self.attn_tp_size,
+                    self.attn_tp_rank,
+                )
+            },
+        )
+
+        # selective projection used to make dt, B and C input dependent
+
+        # time step projection (discretization)
+        # instantiate once and copy inv_dt in init_weights of PretrainedModel
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads // self.attn_tp_size))
+
+        A = torch.empty(
+            divide(self.num_v_heads, self.attn_tp_size), dtype=torch.float32
+        ).uniform_(0, 16)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.A_log._no_weight_decay = True
+
+        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.norm = RMSNormGated(
+            self.head_v_dim,
+            eps=self.layer_norm_epsilon,
+            group_size=None,
+            norm_before_gate=True,
+            device=torch.get_device_module().current_device(),
+            dtype=config.torch_dtype,
+        )
+
+        self.out_proj = RowParallelLinear(
+            self.value_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            input_is_parallel=True,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
+        """
+        Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
+        """
+        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
+            self.num_k_heads // self.attn_tp_size,
+            (
+                self.head_k_dim
+                + self.head_k_dim
+                + (self.head_v_dim + self.head_v_dim)
+                * self.num_v_heads
+                // self.num_k_heads
+            ),
+        )
+        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
+            self.num_k_heads // self.attn_tp_size,
+            2 * self.num_v_heads // self.num_k_heads,
+        )
+
+        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
+        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
+
+        split_arg_list_qkvz = [
+            self.head_k_dim,
+            self.head_k_dim,
+            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+        ]
+        split_arg_list_ba = [
+            self.num_v_heads // self.num_k_heads,
+            self.num_v_heads // self.num_k_heads,
+        ]
+
+        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
+        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
+        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
+        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
+
+        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
+        value = value.reshape(value.size(0), -1, self.head_v_dim)
+        z = z.reshape(z.size(0), -1, self.head_v_dim)
+        b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
+        a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
+
+        return query, key, value, z, b, a
+
+    def _forward_input_proj(self, hidden_states: torch.Tensor):
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
+        seq_len, _ = hidden_states.shape
+        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+            with torch.cuda.stream(self.alt_stream):
+                projected_states_ba, _ = self.in_proj_ba(hidden_states)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+            projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        return projected_states_qkvz, projected_states_ba
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        seq_len, _ = hidden_states.shape
+        is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()
+
+        projected_states_qkvz, projected_states_ba = self._forward_input_proj(
+            hidden_states
+        )
+
+        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph:
+            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
+                projected_states_qkvz,
+                projected_states_ba,
+                triton.cdiv(self.num_k_heads, self.attn_tp_size),
+                triton.cdiv(self.num_v_heads, self.attn_tp_size),
+                self.head_k_dim,
+                self.head_v_dim,
+            )
+        else:
+            query, key, value, z, b, a = self.fix_query_key_value_ordering(
+                projected_states_qkvz, projected_states_ba
+            )
+            query, key, value = map(
+                lambda x: x.reshape(x.shape[0], -1), (query, key, value)
+            )
+            mixed_qkv = torch.cat((query, key, value), dim=-1)
+        # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l")
+
+        # 2. Convolution sequence transformation
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+
+        kwargs = {
+            "mixed_qkv": mixed_qkv,
+            "conv_weights": conv_weights,
+            "bias": self.conv1d.bias,
+            "activation": self.activation,
+            "key_dim": self.key_dim,
+            "value_dim": self.value_dim,
+            "attention_tp_size": self.attn_tp_size,
+            "head_k_dim": self.head_k_dim,
+            "head_v_dim": self.head_v_dim,
+            "a": a,
+            "b": b,
+            "A_log": self.A_log,
+            "dt_bias": self.dt_bias,
+            "layer_id": self.layer_id,
+            "seq_len": seq_len,
+            "num_k_heads": self.num_k_heads,
+            "num_v_heads": self.num_v_heads,
+            "z": z,
+        }
+
+        core_attn_out = forward_batch.attn_backend.forward(
+            q=None,
+            k=None,
+            v=None,
+            layer=None,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+        z_shape_og = z.shape
+        # reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
+
+        output, _ = self.out_proj(core_attn_out)
+        return output
+
+
+class Qwen3HybridLinearDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_attn = Qwen3GatedDeltaNet(
+            config, layer_id, quant_config, alt_stream
+        )
+
+        # Qwen3Next all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+        self.layer_id = layer_id
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        if self.is_layer_sparse:
+            self.mlp = Qwen2MoeSparseMoeBlock(
+                layer_id=layer_id,
+                config=config,
+                quant_config=quant_config,
+                alt_stream=alt_stream,
+            )
+        else:
+            self.mlp = Qwen2MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        **kwargs,
+    ):
+        forward_batch = kwargs.get("forward_batch", None)
+
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states = self.linear_attn(
+                hidden_states,
+                forward_batch,
+            )
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+class Qwen3HybridAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % self.attn_tp_size == 0
+        self.num_heads = self.total_num_heads // self.attn_tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= self.attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.partial_rotary_factor = config.partial_rotary_factor
+        self.layer_id = layer_id
+
+        self.attn_output_gate = getattr(config, "attn_output_gate", True)
+        if self.attn_output_gate:
+            logger.warning_once("using attn output gate!")
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            rope_scaling=self.rope_scaling,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads * (1 + self.attn_output_gate),
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        # Qwen3Next all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        if self.is_layer_sparse:
+            self.mlp = Qwen2MoeSparseMoeBlock(
+                layer_id=layer_id,
+                config=config,
+                quant_config=quant_config,
+                alt_stream=alt_stream,
+            )
+        else:
+            self.mlp = Qwen2MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+        self.alt_stream = alt_stream
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
+        q = q_by_head.view(q.shape)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+
+        if self.attn_output_gate:
+            q_gate, k, v = qkv.split(
+                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
+            )
+            orig_shape = q_gate.shape[:-1]
+            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
+            q, gate = torch.chunk(q_gate, 2, dim=-1)
+            q = q.reshape(*orig_shape, -1)
+            gate = gate.reshape(*orig_shape, -1)
+        else:
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self._apply_qk_norm(q, k)
+
+        q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+
+        if self.attn_output_gate:
+            gate = torch.sigmoid(gate)
+            attn_output = attn_output * gate
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ):
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states = self.self_attention(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "attention": Qwen3HybridAttentionDecoderLayer,
+    "linear_attention": Qwen3HybridLinearDecoderLayer,
+}
+
+
+class Qwen3NextModel(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            enable_tp=not is_dp_attention_enabled(),
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+            return layer_class(
+                config,
+                idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            )
+
+        self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.infer_count = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        # mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embed_tokens(input_ids)
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                hidden_states, residual = layer(
+                    layer_id=i,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    residual=residual,
+                    forward_batch=forward_batch,
+                )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class Qwen3NextForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+        self.quant_config = quant_config
+        self.model = Qwen3NextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.lm_head = self.lm_head.float()
+        self.logits_processor = LogitsProcessor(config)
+
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            if is_mtp:
+
+                if "mtp" not in name:
+                    continue
+
+                if name in [
+                    "mtp.fc.weight",
+                    "mtp.pre_fc_norm_embedding.weight",
+                    "mtp.pre_fc_norm_hidden.weight",
+                ]:
+                    name = name.replace("mtp.", "")
+                else:
+                    name = name.replace("mtp", "model")
+
+            if not is_mtp and "mtp" in name:
+                continue
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                # TODO(fix mtp loading)
+                if "mlp.experts" in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader")
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    # if is_pp_missing_parameter(name, self):
+                    #     continue
+                    # Skip loading extra bias for GPTQ models.
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+
+                    weight_loader = getattr(param, "weight_loader")
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # if is_pp_missing_parameter(name, self):
+                    #     continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+
+
+EntryClass = Qwen3NextForCausalLM
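
Note: the hunk above is the new sglang/srt/models/qwen3_next.py, which registers Qwen3-Next support via EntryClass = Qwen3NextForCausalLM. As a usage illustration only (not part of the diff), the following is a minimal sketch of exercising the new model class through sglang's offline Engine API; the checkpoint id and tensor-parallel size are assumptions, not taken from this release.

# Illustrative sketch only; assumes a Qwen3-Next checkpoint is reachable under
# the (assumed) Hugging Face id below and that enough GPUs are available.
import sglang as sgl

if __name__ == "__main__":
    # Engine() accepts the same arguments as `python -m sglang.launch_server`.
    llm = sgl.Engine(
        model_path="Qwen/Qwen3-Next-80B-A3B-Instruct",  # assumed checkpoint id
        tp_size=4,  # assumed; match to your GPU count
    )
    prompts = ["The capital of France is"]
    sampling_params = {"temperature": 0.0, "max_new_tokens": 32}
    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(prompt, "->", output["text"])
    llm.shutdown()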