sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py
@@ -15,7 +15,7 @@
  """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -30,8 +30,13 @@ from sglang.srt.distributed import (
      parallel_state,
      tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+     use_symmetric_memory,
+ )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.amx_utils import PackWeightMethod
  from sglang.srt.layers.communicator import (
      LayerCommunicator,
      LayerScatterModes,
@@ -44,56 +49,41 @@ from sglang.srt.layers.dp_attention import (
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
-     ColumnParallelLinear,
      MergedColumnParallelLinear,
      QKVParallelLinear,
-     ReplicatedLinear,
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+ from sglang.srt.layers.moe import (
+     get_moe_a2a_backend,
+     should_use_flashinfer_cutlass_moe_fp4_allgather,
+ )
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.quantization.fp8_kernel import (
-     is_fp8_fnuz,
-     per_tensor_quant_mla_fp8,
-     per_token_group_quant_mla_deep_gemm_masked_fp8,
- )
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.utils import PPMissingLayer
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.models.deepseek_v2 import (
-     DeepseekV2DecoderLayer,
-     DeepseekV2ForCausalLM,
-     DeepseekV2Model,
-     DeepseekV2MoE,
- )
- from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+ from sglang.srt.server_args import get_global_server_args
+ from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
  from sglang.srt.utils import (
-     BumpAllocator,
-     LazyValue,
      add_prefix,
-     bind_or_assign,
      cpu_has_amx_support,
      get_bool_env_var,
      get_device_sm,
-     get_int_env_var,
      is_cpu,
      is_cuda,
-     is_flashinfer_available,
      is_hip,
-     is_non_idle_and_non_empty,
-     log_info_on_rank0,
-     use_intel_amx_backend,
+     make_layers,
  )

  _is_hip = is_hip()
@@ -104,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
  _device_sm = get_device_sm()

- if _is_cuda:
-     from sgl_kernel import dsv3_router_gemm
- elif _is_cpu and _is_cpu_amx_available:
-     pass
-
  logger = logging.getLogger(__name__)


@@ -148,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
          )
          if hidden_act != "silu":
              raise ValueError(
-                 f"Unsupported activation: {hidden_act}. "
-                 "Only silu is supported for now."
+                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
              )
          self.act_fn = SiluAndMul()

@@ -158,7 +142,6 @@
          x,
          forward_batch=None,
          should_allreduce_fusion=False,
-         gemm_output_zero_allocator: BumpAllocator = None,
      ):
          if (self.tp_size == 1) and x.shape[0] == 0:
              return x
@@ -338,47 +321,21 @@ class Glm4MoeGate(nn.Module):
          self,
          config,
          prefix: str = "",
-         is_nextn: bool = False,
      ):
          super().__init__()
-         self.is_nextn = is_nextn
          self.weight = nn.Parameter(
              torch.empty((config.n_routed_experts, config.hidden_size))
          )
          self.e_score_correction_bias = nn.Parameter(
              torch.empty((config.n_routed_experts), dtype=torch.float32)
          )
-         if _is_cpu and _is_cpu_amx_available:
-             self.quant_method = PackWeightMethod(weight_names=["weight"])

      def forward(self, hidden_states):
-         if use_intel_amx_backend(self):
-             return torch.ops.sgl_kernel.weight_packed_linear(
-                 hidden_states,
-                 self.weight,
-                 None,  # bias
-                 True,  # is_vnni
-             )
-
-         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
-         if (
-             _is_cuda
-             and not self.is_nextn
-             and hidden_states.shape[0] < 4
-             and hidden_states.shape[1] == 7168
-             and self.weight.shape[0] == 256
-             and _device_sm >= 90
-         ):
-             logits = dsv3_router_gemm(hidden_states, self.weight).to(
-                 hidden_states.dtype
-             )
-         else:
-             logits = F.linear(hidden_states, self.weight, None)
-
+         logits = F.linear(hidden_states, self.weight, None)
          return logits


- class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+ class Glm4MoeSparseMoeBlock(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
@@ -386,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
          alt_stream: Optional[torch.cuda.Stream] = None,
-         is_nextn: bool = False,
      ):
          nn.Module.__init__(self)
+         self.top_k = config.num_experts_per_tok
          self.tp_size = get_tensor_model_parallel_world_size()
-         self.ep_size = get_moe_expert_parallel_world_size()
          self.routed_scaling_factor = config.routed_scaling_factor
          self.n_shared_experts = config.n_shared_experts
-         self.num_fused_shared_experts = (
-             0
-             if global_server_args_dict["disable_shared_experts_fusion"]
-             else config.n_shared_experts
-         )
          self.config = config
          self.layer_id = layer_id
          self.alt_stream = alt_stream
@@ -414,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                  "Only silu is supported for now."
              )

-         self.gate = Glm4MoeGate(
-             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
-         )
+         self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))

          self.topk = TopK(
-             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+             top_k=self.top_k,
              renormalize=config.norm_topk_prob,
              use_grouped_topk=True,
              num_expert_group=config.n_group,
-             num_fused_shared_experts=self.num_fused_shared_experts,
              topk_group=config.topk_group,
              correction_bias=self.gate.e_score_correction_bias,
              routed_scaling_factor=self.routed_scaling_factor,
          )

          self.experts = get_moe_impl_class(quant_config)(
-             num_experts=config.n_routed_experts
-             + self.num_fused_shared_experts
-             + global_server_args_dict["ep_num_redundant_experts"],
-             num_fused_shared_experts=self.num_fused_shared_experts,
-             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+             num_experts=config.n_routed_experts,
+             top_k=self.top_k,
+             layer_id=self.layer_id,
              hidden_size=config.hidden_size,
              intermediate_size=config.moe_intermediate_size,
-             layer_id=self.layer_id,
              quant_config=quant_config,
              routed_scaling_factor=self.routed_scaling_factor,
              prefix=add_prefix("experts", prefix),
          )

-         self.shared_experts_is_int8 = False
-         self.shared_experts_is_fp8 = False
-         # self.shared_experts_weight_block_size = None
-         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+         # shared expert
+         if config.n_shared_experts is not None:
              intermediate_size = config.moe_intermediate_size * config.n_shared_experts
              self.shared_experts = Glm4MoeMLP(
                  hidden_size=config.hidden_size,
@@ -455,28 +398,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                  quant_config=quant_config,
                  reduce_results=False,
                  prefix=add_prefix("shared_experts", prefix),
-                 **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
-             )
-             is_packed_weight = hasattr(
-                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
-             )
-             self.shared_experts_is_int8 = (
-                 not is_packed_weight
-                 and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+                 **(
+                     dict(tp_rank=0, tp_size=1)
+                     if get_moe_a2a_backend().is_deepep()
+                     or get_moe_a2a_backend().is_mooncake()
+                     or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                     else {}
+                 ),
              )
-             self.shared_experts_is_fp8 = (
-                 not is_packed_weight
-                 and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
-             )
-
-         self.top_k = config.num_experts_per_tok

-         if get_moe_a2a_backend().is_deepep():
+         if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
              # TODO: we will support tp < ep in the future
              self.ep_size = get_moe_expert_parallel_world_size()
              self.num_experts = (
                  config.n_routed_experts
-                 + global_server_args_dict["ep_num_redundant_experts"]
+                 + get_global_server_args().ep_num_redundant_experts
              )
              self.renormalize = config.norm_topk_prob
              self.topk_group = config.topk_group
@@ -487,27 +423,50 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                  else None
              )

-             self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-                 group=parallel_state.get_tp_group().device_group,
-                 router_topk=self.top_k,
-                 permute_fusion=True,
-                 num_experts=self.num_experts,
-                 num_local_experts=config.n_routed_experts // self.tp_size,
-                 hidden_size=config.hidden_size,
-                 params_dtype=config.torch_dtype,
-                 deepep_mode=get_deepep_mode(),
-                 async_finish=True,
-                 return_recv_hook=True,
-             )
+         self._enable_a2a_moe = (
+             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+         )
+
+     def get_moe_weights(self):
+         return [
+             x.data
+             for name, x in self.experts.named_parameters()
+             if name not in ["correction_bias"]
+         ]

-         self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: Optional[ForwardBatch] = None,
+         should_allreduce_fusion: bool = False,
+         use_reduce_scatter: bool = False,
+     ) -> torch.Tensor:
+         if not self._enable_a2a_moe:
+             DUAL_STREAM_TOKEN_THRESHOLD = 1024
+             if (
+                 self.alt_stream is not None
+                 and hidden_states.shape[0] > 0
+                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+             ):
+                 return self.forward_normal_dual_stream(
+                     hidden_states,
+                     should_allreduce_fusion,
+                     use_reduce_scatter,
+                 )
+             else:
+                 return self.forward_normal(
+                     hidden_states,
+                     should_allreduce_fusion,
+                     use_reduce_scatter,
+                 )
+         else:
+             return self.forward_deepep(hidden_states, forward_batch)

      def forward_normal_dual_stream(
          self,
          hidden_states: torch.Tensor,
          should_allreduce_fusion: bool = False,
          use_reduce_scatter: bool = False,
-         gemm_output_zero_allocator: BumpAllocator = None,
      ) -> torch.Tensor:

          current_stream = torch.cuda.current_stream()
@@ -521,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
          final_hidden_states = self.experts(hidden_states, topk_output)
          if not _is_cuda:
              final_hidden_states *= self.routed_scaling_factor
+
          current_stream.wait_stream(self.alt_stream)
+         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+             final_hidden_states_out = torch.empty_like(final_hidden_states)

-         if self.ep_size > 1:
-             if (
-                 self.tp_size > 1
-                 and not should_allreduce_fusion
-                 and not use_reduce_scatter
-             ):
-                 final_hidden_states = tensor_model_parallel_all_reduce(
-                     final_hidden_states
-                 )
-             final_hidden_states += shared_output
-         else:
-             final_hidden_states += shared_output
-             if (
-                 self.tp_size > 1
-                 and not should_allreduce_fusion
-                 and not use_reduce_scatter
-             ):
-                 final_hidden_states = tensor_model_parallel_all_reduce(
-                     final_hidden_states
-                 )
+             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+             final_hidden_states = final_hidden_states_out
+             sm.tag(final_hidden_states)
+         if (
+             self.tp_size > 1
+             and not should_allreduce_fusion
+             and not use_reduce_scatter
+             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+         ):
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
          return final_hidden_states

      def forward_normal(
@@ -550,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
          hidden_states: torch.Tensor,
          should_allreduce_fusion: bool = False,
          use_reduce_scatter: bool = False,
-         gemm_output_zero_allocator: BumpAllocator = None,
      ) -> torch.Tensor:
-         if hasattr(self, "shared_experts") and use_intel_amx_backend(
-             self.shared_experts.gate_up_proj
-         ):
-             return self.forward_cpu(hidden_states, should_allreduce_fusion)
+         if hidden_states.shape[0] > 0:
+             shared_output = self._forward_shared_experts(hidden_states)
+             # router_logits: (num_tokens, n_experts)
+             router_logits = self.gate(hidden_states)
+             topk_output = self.topk(hidden_states, router_logits)
+         else:
+             shared_output = None
+             topk_output = self.topk.empty_topk_output(hidden_states.device)

-         shared_output = self._forward_shared_experts(hidden_states)
-         # router_logits: (num_tokens, n_experts)
-         router_logits = self.gate(hidden_states)
-         topk_output = self.topk(hidden_states, router_logits)
          final_hidden_states = self.experts(hidden_states, topk_output)
          if not _is_cuda and not _use_aiter:
              # fused in biased_grouped_topk so we can skip here
              final_hidden_states *= self.routed_scaling_factor
-         if self.ep_size > 1:
-             if self.tp_size > 1 and not should_allreduce_fusion:
-                 final_hidden_states = tensor_model_parallel_all_reduce(
-                     final_hidden_states
-                 )
-             if shared_output is not None:
-                 final_hidden_states += shared_output
+         if shared_output is not None:
+             with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                 final_hidden_states_out = torch.empty_like(final_hidden_states)
+                 torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                 final_hidden_states = final_hidden_states_out
+                 sm.tag(final_hidden_states)
+         if (
+             self.tp_size > 1
+             and not should_allreduce_fusion
+             and not use_reduce_scatter
+             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+         ):
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+         return final_hidden_states
+
+     def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+         shared_output = None
+         if hidden_states.shape[0] > 0:
+             # router_logits: (num_tokens, n_experts)
+             router_logits, _ = self.gate(hidden_states)
+             shared_output = self._forward_shared_experts(hidden_states)
+             topk_output = self.topk(
+                 hidden_states,
+                 router_logits,
+                 num_token_non_padded=forward_batch.num_token_non_padded,
+                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                     layer_id=self.layer_id,
+                 ),
+             )
          else:
-             if shared_output is not None:
-                 final_hidden_states += shared_output
-             if self.tp_size > 1 and not should_allreduce_fusion:
-                 final_hidden_states = tensor_model_parallel_all_reduce(
-                     final_hidden_states
-                 )
+             topk_output = self.topk.empty_topk_output(hidden_states.device)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states,
+             topk_output=topk_output,
+         )
+
+         if shared_output is not None:
+             final_hidden_states.add_(shared_output)
+
          return final_hidden_states

+     def _forward_shared_experts(self, hidden_states: torch.Tensor):
+         shared_output = None
+         if hidden_states.shape[0] > 0:
+             shared_output = self.shared_experts(hidden_states)
+         return shared_output
+

- class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+ class Glm4MoeDecoderLayer(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
@@ -605,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
          rms_norm_eps = config.rms_norm_eps
          attention_bias = config.attention_bias
          self.layer_id = layer_id
+
          self.self_attn = Glm4MoeAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
@@ -620,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
              quant_config=quant_config,
              prefix=add_prefix("self_attn", prefix),
              use_qk_norm=config.use_qk_norm,
+             alt_stream=alt_stream,
          )

          self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
          is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

-         num_layers = 1 if is_nextn else config.num_hidden_layers
          self.layer_scatter_modes = LayerScatterModes.init_new(
              layer_id=layer_id,
-             num_layers=num_layers,
+             num_layers=1 if is_nextn else config.num_hidden_layers,
              is_layer_sparse=self.is_layer_sparse,
              is_previous_layer_sparse=is_previous_layer_sparse,
          )
@@ -639,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
                  quant_config=quant_config,
                  prefix=add_prefix("mlp", prefix),
                  layer_id=self.layer_id,
+                 alt_stream=alt_stream,
              )
          else:
              if enable_moe_dense_fully_dp():
@@ -665,6 +649,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
              input_layernorm=self.input_layernorm,
              post_attention_layernorm=self.post_attention_layernorm,
              allow_reduce_scatter=True,
+             is_last_layer=(
+                 is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+             ),
+         )
+
+     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+         return is_nextn or (
+             self.config.n_routed_experts is not None
+             and layer_id >= self.config.first_k_dense_replace
          )

      def forward(
@@ -673,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
          hidden_states: torch.Tensor,
          forward_batch: ForwardBatch,
          residual: Optional[torch.Tensor],
-         zero_allocator: BumpAllocator,
-         gemm_output_zero_allocator: BumpAllocator = None,
      ) -> torch.Tensor:
          hidden_states, residual = self.layer_communicator.prepare_attn(
              hidden_states, residual, forward_batch
@@ -699,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
          return hidden_states, residual


- class Glm4MoeModel(DeepseekV2Model):
+ class Glm4MoeModel(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
-     ) -> None:
-         nn.Module.__init__(self)
-         self.padding_id = config.pad_token_id
+     ):
+         super().__init__()
+         self.pp_group = get_pp_group()
+         self.config = config
          self.vocab_size = config.vocab_size
-         self.first_k_dense_replace = config.first_k_dense_replace
+         self.embed_dim = config.hidden_size
+         if self.pp_group.is_first_rank:
+             self.embed_tokens = VocabParallelEmbedding(
+                 config.vocab_size,
+                 config.hidden_size,
+                 enable_tp=not is_dp_attention_enabled(),
+             )
+         else:
+             self.embed_tokens = PPMissingLayer()

-         self.embed_tokens = VocabParallelEmbedding(
-             config.vocab_size,
-             config.hidden_size,
-             enable_tp=not is_dp_attention_enabled(),
-         )
          self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-         self.layers = nn.ModuleList(
-             [
-                 Glm4MoeDecoderLayer(
-                     config,
-                     layer_id,
-                     quant_config=quant_config,
-                     prefix=add_prefix(f"layers.{layer_id}", prefix),
-                     alt_stream=self.alt_stream,
-                 )
-                 for layer_id in range(config.num_hidden_layers)
-             ]
+         self.layers, self.start_layer, self.end_layer = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Glm4MoeDecoderLayer(
+                 layer_id=idx,
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=prefix,
+                 alt_stream=self.alt_stream,
+             ),
+             pp_rank=self.pp_group.rank_in_group,
+             pp_size=self.pp_group.world_size,
+             prefix=add_prefix("layers", prefix),
          )
-         self.pp_group = get_pp_group()
-         self.start_layer = 0
-         self.end_layer = config.num_hidden_layers
-         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         if self.pp_group.is_last_rank:
+             self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+         else:
+             self.norm = PPMissingLayer(return_tuple=True)
+
+     def get_input_embeddings(self) -> torch.Tensor:
+         return self.embed_tokens
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[torch.Tensor, PPProxyTensors]:
+         if self.pp_group.is_first_rank:
+             if input_embeds is None:
+                 hidden_states = self.embed_tokens(input_ids)
+             else:
+                 hidden_states = input_embeds
+             residual = None
+         else:
+             assert pp_proxy_tensors is not None
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
+
+         normal_start_layer = self.start_layer
+         normal_end_layer = self.end_layer
+         if forward_batch.can_run_tbo:
+             if (
+                 self.first_k_dense_replace > normal_start_layer
+                 and self.first_k_dense_replace < normal_end_layer
+             ):
+                 normal_end_layer = self.first_k_dense_replace
+             elif self.first_k_dense_replace < normal_start_layer:
+                 normal_end_layer = normal_start_layer = 0
+
+         for i in range(normal_start_layer, normal_end_layer):
+             with get_global_expert_distribution_recorder().with_current_layer(i):
+                 layer = self.layers[i]
+                 hidden_states, residual = layer(
+                     positions,
+                     hidden_states,
+                     forward_batch,
+                     residual,
+                 )

+         if normal_end_layer != self.end_layer:
+             hidden_states, residual = model_forward_maybe_tbo(
+                 layers=self.layers[normal_end_layer : self.end_layer],
+                 enable_tbo=True,
+                 positions=positions,
+                 forward_batch=forward_batch,
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 input_data_scatter_mode=self.layers[
+                     normal_end_layer - 1
+                 ].layer_scatter_modes.layer_output_mode,
+             )
+
+         if not self.pp_group.is_last_rank:
+             return PPProxyTensors(
+                 {
+                     "hidden_states": hidden_states,
+                     "residual": residual,
+                 }
+             )
+         else:
+             if not forward_batch.forward_mode.is_idle():
+                 if residual is None:
+                     hidden_states = self.norm(hidden_states)
+                 else:
+                     hidden_states, _ = self.norm(hidden_states, residual)
+             return hidden_states

- class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):

+ class Glm4MoeForCausalLM(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
@@ -744,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
          prefix: str = "",
      ) -> None:
          nn.Module.__init__(self)
-         config.moe_layer_freq = 1
          self.config = config
          self.tp_size = get_tensor_model_parallel_world_size()
          self.quant_config = quant_config
          self.pp_group = get_pp_group()
-         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
          self.model = Glm4MoeModel(
              config, quant_config, prefix=add_prefix("model", prefix)
          )
@@ -758,53 +822,45 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
              config.hidden_size,
              quant_config=quant_config,
              prefix=add_prefix("lm_head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.logits_processor = LogitsProcessor(config)

-         self._routed_experts_weights_of_layer = LazyValue(
-             lambda: {
-                 layer_id: layer.mlp.get_moe_weights()
-                 for layer_id, layer in enumerate(self.model.layers)
-                 if isinstance(layer.mlp, DeepseekV2MoE)
-             }
-         )
+         # For EAGLE3 support
+         self.capture_aux_hidden_states = False

-     def determine_num_fused_shared_experts(
-         self, architecture: str = "Glm4MoeForCausalLM"
-     ):
-         self.num_fused_shared_experts = 0
-         if global_server_args_dict["disable_shared_experts_fusion"]:
-             return
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.model.embed_tokens

-         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-         disable_reason = None
-         if (
-             not _is_cuda
-             or torch.cuda.get_device_capability("cuda") < (8, 0)
-             or self.config.architectures[0] != architecture
-             or self.config.n_shared_experts != 1
-         ):
-             disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-         elif get_moe_expert_parallel_world_size() > 1:
-             disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
-
-         if disable_reason is not None:
-             global_server_args_dict["disable_shared_experts_fusion"] = True
-             self.num_fused_shared_experts = 0
-             log_info_on_rank0(
-                 logger,
-                 f"{disable_reason} Shared experts fusion optimization is disabled.",
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(
+             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
+         )
+
+         if self.pp_group.is_last_rank:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head, forward_batch
              )
-             return
+         else:
+             return hidden_states

-         self.num_fused_shared_experts = self.config.n_shared_experts
+     @property
+     def start_layer(self):
+         return self.model.start_layer

-     def get_input_embeddings(self) -> nn.Embedding:
-         return self.model.embed_tokens
+     @property
+     def end_layer(self):
+         return self.model.end_layer

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
          if is_nextn:
              if hasattr(self.config, "num_nextn_predict_layers"):
                  num_nextn_layers = self.config.num_nextn_predict_layers
@@ -826,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
-         if self.num_fused_shared_experts > 0:
-             assert self.num_fused_shared_experts == 1
-             weights_list = list(weights)
-             weights_dict = dict(weights_list)
-             if self.quant_config is not None:
-                 if self.quant_config.get_name() == "w8a8_int8":
-                     suffix_list = [
-                         "down_proj.weight",
-                         "down_proj.weight_scale",
-                         "gate_proj.weight",
-                         "gate_proj.weight_scale",
-                         "up_proj.weight",
-                         "up_proj.weight_scale",
-                     ]
-                 elif (
-                     self.quant_config.get_name() == "fp8"
-                     or self.quant_config.get_name() == "blockwise_int8"
-                     or self.quant_config.get_name() == "compressed_tensors"
-                 ):
-                     suffix_list = [
-                         "down_proj.weight",
-                         "down_proj.weight_scale",
-                         "gate_proj.weight",
-                         "gate_proj.weight_scale",
-                         "up_proj.weight",
-                         "up_proj.weight_scale",
-                     ]
-                 elif self.quant_config.get_name() == "awq":
-                     suffix_list = [
-                         "down_proj.qweight",
-                         "down_proj.qzeros",
-                         "down_proj.scales",
-                         "gate_proj.qweight",
-                         "gate_proj.qzeros",
-                         "gate_proj.scales",
-                         "up_proj.qweight",
-                         "up_proj.qzeros",
-                         "up_proj.scales",
-                     ]
-                 elif self.quant_config.get_name() == "modelopt_fp4":
-                     suffix_list = [
-                         "down_proj.weight",
-                         "down_proj.weight_scale",
-                         "down_proj.weight_scale_2",
-                         "down_proj.input_scale",
-                         "gate_proj.weight",
-                         "gate_proj.weight_scale",
-                         "gate_proj.weight_scale_2",
-                         "gate_proj.input_scale",
-                         "up_proj.weight",
-                         "up_proj.weight_scale",
-                         "up_proj.weight_scale_2",
-                         "up_proj.input_scale",
-                     ]
-                 else:
-                     raise ValueError(
-                         f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                     )
-             else:
-                 suffix_list = [
-                     "down_proj.weight",
-                     "gate_proj.weight",
-                     "up_proj.weight",
-                 ]
-             names_to_remove = []
-
-             moe_layers = (
-                 range(
-                     self.config.first_k_dense_replace,
-                     self.config.num_hidden_layers,
-                     self.config.moe_layer_freq,
-                 )
-                 if not is_nextn
-                 else [nextn_layer_id]
-             )
-
-             for moe_layer in moe_layers:
-                 for suffix in suffix_list:
-                     shared_expert_weight_name = (
-                         f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                     )
-                     # online fp8 quantization does not load weight_scale
-                     if shared_expert_weight_name not in weights_dict:
-                         continue
-                     weights_list.append(
-                         (
-                             f"model.layers.{moe_layer}."
-                             f"mlp.experts."
-                             f"{self.config.n_routed_experts + 0}"
-                             f".{suffix}",
-                             weights_dict[shared_expert_weight_name],
-                         )
-                     )
-                     names_to_remove += [shared_expert_weight_name]
-             weights = [w for w in weights_list if w[0] not in names_to_remove]

-         # Params for weights, fp8 weight scales, fp8 activation scales
-         # (param_name, weight_name, expert_id, shard_id)
          expert_params_mapping = FusedMoE.make_expert_params_mapping(
              ckpt_gate_proj_name="gate_proj",
              ckpt_down_proj_name="down_proj",
              ckpt_up_proj_name="up_proj",
-             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+             num_experts=self.config.n_routed_experts,
          )

-         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-             self.config.q_lora_rank is not None
-         )
-         cached_a_proj = {} if fuse_qkv_a_proj else None
-
          if is_nextn:
              nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
              nextn_spec_weight_names = [
@@ -992,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                  # name will be updated to mlp.experts[0].gate_up_proj, which
                  # will then be updated below in expert_params_mapping
                  # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                 if ("mlp.experts." in name) and name not in params_dict:
+                 if "mlp.experts" in name:
                      continue
                  name = name.replace(weight_name, param_name)
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name not in params_dict:
+                     continue
+
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
+                 # Track if this is an expert weight to enable early skipping
+                 is_expert_weight = False
+
                  for mapping in expert_params_mapping:
                      param_name, weight_name, expert_id, shard_id = mapping
                      if weight_name not in name:
                          continue
+
+                     # Mark as expert weight regardless of whether we can process it
+                     is_expert_weight = True
+
                      name = name.replace(weight_name, param_name)
+                     if name not in params_dict:
+                         # Expert weight not on this rank, will be skipped below
+                         continue
+
                      param = params_dict[name]
                      weight_loader = param.weight_loader
                      weight_loader(
@@ -1019,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                      )
                      break
                  else:
+                     if is_expert_weight:
+                         # This is an expert weight but not mapped to this rank, skip all remaining processing
+                         continue
+
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
                          continue
-                     if fuse_qkv_a_proj and (
-                         "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                     ):
-                         cached_a_proj[name] = loaded_weight
-                         q_a_proj_name = (
-                             name
-                             if "q_a_proj" in name
-                             else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                         )
-                         kv_a_proj_name = (
-                             name
-                             if "kv_a_proj_with_mqa" in name
-                             else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                         )
+                     if name not in params_dict:
+                         continue

-                         # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                         if (
-                             q_a_proj_name in cached_a_proj
-                             and kv_a_proj_name in cached_a_proj
-                         ):
-                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                             fused_weight = torch.cat(
-                                 [q_a_proj_weight, kv_a_proj_weight], dim=0
-                             )
-                             param_name = (
-                                 name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                 if "q_a_proj" in name
-                                 else name.replace(
-                                     "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                 )
-                             )
-                             param = params_dict[param_name]
-
-                             weight_loader = getattr(
-                                 param, "weight_loader", default_weight_loader
-                             )
-                             weight_loader(param, fused_weight)
-                             cached_a_proj.pop(q_a_proj_name)
-                             cached_a_proj.pop(kv_a_proj_name)
-                     else:
-                         if (
-                             "k_scale" in name or "v_scale" in name
-                         ) and name not in params_dict:
-                             # modelopt attn kv scale is named differently
-                             if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                 name = name.replace("_proj", "attn_mqa")
-                             else:
-                                 logger.warning(
-                                     f"Unknown scale found in checkpoint: {name}"
-                                 )
+                     if name in params_dict.keys():
                          param = params_dict[name]
                          weight_loader = getattr(
                              param, "weight_loader", default_weight_loader
                          )
                          weight_loader(param, loaded_weight)
+                     else:
+                         logger.warning(f"Parameter {name} not found in params_dict")
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         return ModelConfigForExpertLocation(
+             num_layers=config.num_hidden_layers,
+             num_logical_experts=config.n_routed_experts,
+             num_groups=config.n_group,
+         )


  EntryClass = [Glm4MoeForCausalLM]