sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py
@@ -1,377 +1,907 @@
1
- # Copyright 2023-2024 SGLang Team
2
- # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
3
-
4
- from collections.abc import Iterable
5
- from typing import Optional, Tuple
1
+ # coding=utf-8
2
+ # Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ SGLang BailingMoE model."""
21
+ import logging
22
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
6
23
 
7
24
  import torch
8
25
  import torch.nn.functional as F
9
26
  from torch import nn
10
- from transformers.configuration_utils import PretrainedConfig
27
+ from transformers import PretrainedConfig
11
28
 
12
29
  from sglang.srt.distributed import (
30
+ get_pp_group,
13
31
  get_tensor_model_parallel_world_size,
32
+ parallel_state,
14
33
  tensor_model_parallel_all_reduce,
15
34
  )
35
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
36
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
37
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
16
38
  from sglang.srt.layers.activation import SiluAndMul
39
+ from sglang.srt.layers.communicator import (
40
+ LayerCommunicator,
41
+ LayerScatterModes,
42
+ enable_moe_dense_fully_dp,
43
+ )
44
+ from sglang.srt.layers.dp_attention import (
45
+ get_attention_dp_size,
46
+ get_attention_tp_rank,
47
+ get_attention_tp_size,
48
+ is_dp_attention_enabled,
49
+ )
17
50
  from sglang.srt.layers.layernorm import RMSNorm
18
51
  from sglang.srt.layers.linear import (
19
52
  MergedColumnParallelLinear,
20
53
  QKVParallelLinear,
21
- ReplicatedLinear,
22
54
  RowParallelLinear,
23
55
  )
24
56
  from sglang.srt.layers.logits_processor import LogitsProcessor
25
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
57
+ from sglang.srt.layers.moe import get_moe_a2a_backend
58
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
59
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
60
+ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
26
61
  from sglang.srt.layers.moe.topk import TopK
62
+ from sglang.srt.layers.moe.utils import DeepEPMode
27
63
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
28
64
  from sglang.srt.layers.radix_attention import RadixAttention
29
65
  from sglang.srt.layers.rotary_embedding import get_rope
66
+ from sglang.srt.layers.utils import PPMissingLayer
30
67
  from sglang.srt.layers.vocab_parallel_embedding import (
31
68
  ParallelLMHead,
32
69
  VocabParallelEmbedding,
33
70
  )
34
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
71
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
72
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
73
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
35
74
  from sglang.srt.model_loader.weight_utils import default_weight_loader
36
- from sglang.srt.utils import add_prefix, make_layers
75
+ from sglang.srt.models.utils import (
76
+ create_fused_set_kv_buffer_arg,
77
+ enable_fused_set_kv_buffer,
78
+ )
79
+ from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
37
80
 
81
+ LoraConfig = None
82
+ logger = logging.getLogger(__name__)
83
+ _is_cuda = is_cuda()
38
84
 
39
- class BailingAttention(nn.Module):
40
85
 
86
+ class BailingMoEMLP(nn.Module):
41
87
  def __init__(
42
88
  self,
89
+ intermediate_size: int,
43
90
  config: PretrainedConfig,
44
- layer_id: int = 0,
45
91
  quant_config: Optional[QuantizationConfig] = None,
92
+ reduce_results: Optional[bool] = True,
46
93
  prefix: str = "",
47
- ):
94
+ tp_rank: Optional[int] = None,
95
+ tp_size: Optional[int] = None,
96
+ ) -> None:
48
97
  super().__init__()
49
- self.hidden_size = config.hidden_size
50
- tp_size = get_tensor_model_parallel_world_size()
51
-
52
- self.total_num_heads = config.num_attention_heads
53
- self.total_num_kv_heads = config.num_key_value_heads
54
-
55
- assert self.total_num_heads % tp_size == 0
56
- assert self.total_num_kv_heads % tp_size == 0
57
-
58
- self.num_heads = self.total_num_heads // tp_size
59
- self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
60
- self.q_size = self.num_heads * self.head_dim
61
-
62
- self.num_kv_heads = self.total_num_kv_heads // tp_size
63
- self.kv_size = self.num_kv_heads * self.head_dim
64
- self.scale = self.head_dim**-0.5
65
-
66
- self.query_key_value = QKVParallelLinear(
67
- self.hidden_size,
68
- self.head_dim,
69
- self.total_num_heads,
70
- self.total_num_kv_heads,
71
- bias=(config.use_bias or config.use_qkv_bias),
72
- quant_config=quant_config,
73
- prefix=add_prefix("query_key_value", prefix),
74
- )
98
+ self.tp_size = tp_size
75
99
 
76
- self.dense = RowParallelLinear(
77
- self.total_num_heads * self.head_dim,
78
- self.hidden_size,
100
+ self.gate_up_proj = MergedColumnParallelLinear(
101
+ config.hidden_size,
102
+ [intermediate_size] * 2,
79
103
  bias=config.use_bias,
80
104
  quant_config=quant_config,
81
- prefix=add_prefix("dense", prefix),
105
+ prefix=add_prefix("gate_up_proj", prefix),
106
+ tp_rank=tp_rank,
107
+ tp_size=tp_size,
82
108
  )
83
-
84
- self.attn = RadixAttention(
85
- self.num_heads,
86
- self.head_dim,
87
- self.scale,
88
- num_kv_heads=self.num_kv_heads,
89
- layer_id=layer_id,
109
+ self.down_proj = RowParallelLinear(
110
+ intermediate_size,
111
+ config.hidden_size,
112
+ bias=config.use_bias,
113
+ reduce_results=reduce_results,
90
114
  quant_config=quant_config,
91
- prefix=add_prefix("attn", prefix),
115
+ prefix=add_prefix("down_proj", prefix),
116
+ tp_rank=tp_rank,
117
+ tp_size=tp_size,
92
118
  )
93
119
 
94
- self.rotary_emb = get_rope(
95
- self.head_dim,
96
- rotary_dim=self.head_dim,
97
- max_position=config.max_position_embeddings,
98
- base=config.rope_theta,
99
- is_neox_style=True,
100
- rope_scaling=config.rope_scaling,
101
- )
120
+ if config.hidden_act != "silu":
121
+ raise ValueError("Unsupported activation. Only silu is supported for now.")
122
+ self.act_fn = SiluAndMul()
102
123
 
103
124
  def forward(
104
125
  self,
105
126
  hidden_states: torch.Tensor,
106
- position_ids: torch.Tensor,
107
- forward_batch: ForwardBatch,
127
+ forward_batch: Optional[ForwardBatch] = None,
128
+ use_reduce_scatter: bool = False,
108
129
  ) -> torch.Tensor:
109
- qkv, _ = self.query_key_value(hidden_states)
110
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
130
+ if (self.tp_size == 1) and hidden_states.shape[0] == 0:
131
+ return hidden_states
111
132
 
112
- q, k = self.rotary_emb(position_ids, q, k)
113
- context_layer = self.attn(q, k, v, forward_batch)
114
- attn_output, _ = self.dense(context_layer)
115
- return attn_output
133
+ gate_up, _ = self.gate_up_proj(hidden_states)
134
+ hidden_states = self.act_fn(gate_up)
135
+ hidden_states, _ = self.down_proj(
136
+ hidden_states, skip_all_reduce=use_reduce_scatter
137
+ )
138
+ return hidden_states
116
139
 
117
140
 
118
- class BailingMLP(nn.Module):
141
+ class BailingMoEGate(nn.Module):
119
142
  def __init__(
120
143
  self,
121
- intermediate_size: int,
122
- config: PretrainedConfig,
123
- quant_config: Optional[QuantizationConfig] = None,
124
- reduce_results: Optional[bool] = True,
144
+ config,
145
+ params_dtype: Optional[torch.dtype] = None,
125
146
  prefix: str = "",
126
- ) -> None:
147
+ ):
127
148
  super().__init__()
128
- self.gate_up_proj = MergedColumnParallelLinear(
129
- config.hidden_size,
130
- [intermediate_size] * 2,
131
- bias=config.use_bias,
132
- quant_config=quant_config,
133
- prefix=add_prefix("gate_up_proj", prefix),
134
- )
135
- self.down_proj = RowParallelLinear(
136
- intermediate_size,
137
- config.hidden_size,
138
- bias=config.use_bias,
139
- quant_config=quant_config,
140
- reduce_results=reduce_results,
141
- prefix=add_prefix("down_proj", prefix),
149
+ if params_dtype is None:
150
+ params_dtype = torch.get_default_dtype()
151
+ self.params_dtype = params_dtype
152
+ self.weight = nn.Parameter(
153
+ torch.empty(
154
+ (config.num_experts, config.hidden_size),
155
+ dtype=self.params_dtype,
156
+ ),
142
157
  )
143
- self.act_fn = SiluAndMul()
144
-
145
- def forward(self, x):
146
- x, _ = self.gate_up_proj(x)
147
- x = self.act_fn(x)
148
- x, _ = self.down_proj(x)
149
- return x
158
+ if getattr(config, "moe_router_enable_expert_bias", False):
159
+ self.expert_bias = nn.Parameter(
160
+ torch.empty((config.num_experts,), dtype=torch.float32),
161
+ )
162
+ else:
163
+ self.expert_bias = None
150
164
 
165
+ def forward(self, hidden_states):
166
+ logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
167
+ hidden_states.dtype
168
+ )
169
+ return logits
151
170
 
152
- class BailingMoE(nn.Module):
153
171
 
172
+ class BailingMoESparseMoeBlock(nn.Module):
154
173
  def __init__(
155
174
  self,
156
- config: PretrainedConfig,
157
175
  layer_id: int,
176
+ config: PretrainedConfig,
158
177
  quant_config: Optional[QuantizationConfig] = None,
178
+ alt_stream: Optional[torch.cuda.Stream] = None,
159
179
  prefix: str = "",
160
180
  ):
161
181
  super().__init__()
182
+ self.layer_id = layer_id
183
+ self.alt_stream = alt_stream
162
184
  self.tp_size = get_tensor_model_parallel_world_size()
163
- self.num_experts = config.num_experts
164
185
  self.top_k = config.num_experts_per_tok
186
+ self.norm_topk_prob = config.norm_topk_prob
165
187
  self.hidden_size = config.hidden_size
166
188
  self.num_shared_experts = config.num_shared_experts
167
- self.norm_expert_prob = config.norm_topk_prob
168
- self.moe_intermediate_size = config.moe_intermediate_size
189
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
190
+ self.score_function = getattr(config, "score_function", None)
169
191
 
170
- self.gate = ReplicatedLinear(
171
- self.hidden_size, self.num_experts, bias=False, quant_config=None
192
+ if config.hidden_act != "silu":
193
+ raise ValueError(
194
+ f"Unsupported activation: {config.hidden_act}. "
195
+ "Only silu is supported for now."
196
+ )
197
+
198
+ # Gate always runs at half / full precision for now.
199
+ router_dtype = getattr(config, "router_dtype", None)
200
+ if router_dtype is None:
201
+ self.router_dtype = None
202
+ elif router_dtype == "fp32":
203
+ self.router_dtype = torch.float32
204
+ else:
205
+ self.router_dtype = torch.bfloat16
206
+
207
+ # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
208
+ assert global_server_args_dict["ep_num_redundant_experts"] == 0
209
+ # check group topk
210
+ self.num_expert_group = getattr(config, "n_group", 0)
211
+ self.topk_group = getattr(config, "topk_group", 0)
212
+ if self.num_expert_group > 0 or self.topk_group > 0:
213
+ assert (
214
+ self.num_expert_group > 0
215
+ and 0 < self.topk_group <= self.num_expert_group
216
+ )
217
+ self.use_grouped_topk = True
218
+ else:
219
+ self.num_expert_group = self.topk_group = None
220
+ self.use_grouped_topk = False
221
+
222
+ self.num_experts = (
223
+ config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
224
+ )
225
+
226
+ self.gate = BailingMoEGate(
227
+ config=config,
228
+ params_dtype=self.router_dtype,
229
+ prefix=add_prefix("gate", prefix),
230
+ )
231
+ self.correction_bias = (
232
+ self.gate.expert_bias.data if self.gate.expert_bias is not None else None
172
233
  )
173
234
 
174
- self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
235
+ if self.score_function is not None:
236
+ assert (
237
+ self.score_function == "softmax" and self.correction_bias is None
238
+ ) or (
239
+ self.score_function == "sigmoid" and self.correction_bias is not None
240
+ ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
175
241
 
176
- self.experts = FusedMoE(
242
+ self.topk = TopK(
243
+ top_k=self.top_k,
244
+ renormalize=self.norm_topk_prob,
245
+ use_grouped_topk=self.use_grouped_topk,
246
+ num_expert_group=self.num_expert_group,
247
+ # num_fused_shared_experts=self.num_fused_shared_experts,
248
+ topk_group=self.topk_group,
249
+ correction_bias=self.correction_bias,
250
+ routed_scaling_factor=self.routed_scaling_factor,
251
+ )
252
+
253
+ self.experts = get_moe_impl_class(quant_config)(
177
254
  num_experts=self.num_experts,
178
255
  top_k=self.top_k,
179
- layer_id=layer_id,
180
- hidden_size=self.hidden_size,
181
- intermediate_size=self.moe_intermediate_size,
182
- reduce_results=False,
256
+ layer_id=self.layer_id,
257
+ hidden_size=config.hidden_size,
258
+ intermediate_size=config.moe_intermediate_size,
183
259
  quant_config=quant_config,
260
+ routed_scaling_factor=self.routed_scaling_factor,
184
261
  prefix=add_prefix("experts", prefix),
185
262
  )
186
-
187
- if self.num_shared_experts > 0:
188
- shared_intermediate_size = (
189
- self.moe_intermediate_size * self.num_shared_experts
190
- )
191
- self.shared_experts = BailingMLP(
192
- intermediate_size=shared_intermediate_size,
263
+ # shared expert
264
+ if config.num_shared_experts is not None:
265
+ if hasattr(config, "moe_shared_expert_intermediate_size"):
266
+ intermediate_size = config.moe_shared_expert_intermediate_size
267
+ else:
268
+ intermediate_size = config.moe_intermediate_size
269
+ intermediate_size *= config.num_shared_experts
270
+ # disable tp for shared experts when enable deepep moe
271
+ self.shared_experts = BailingMoEMLP(
272
+ intermediate_size=intermediate_size,
193
273
  config=config,
194
274
  quant_config=quant_config,
195
275
  reduce_results=False,
196
276
  prefix=add_prefix("shared_experts", prefix),
277
+ **(
278
+ dict(tp_rank=0, tp_size=1)
279
+ if get_moe_a2a_backend().is_deepep()
280
+ else {}
281
+ ),
197
282
  )
283
+ # dispatcher
284
+ if get_moe_a2a_backend().is_deepep():
285
+ # TODO: we will support tp < ep in the future
286
+ self.ep_size = get_tensor_model_parallel_world_size()
287
+
288
+ self.deepep_dispatcher = DeepEPDispatcher(
289
+ group=parallel_state.get_tp_group().device_group,
290
+ router_topk=self.top_k,
291
+ permute_fusion=True,
292
+ num_experts=self.num_experts,
293
+ num_local_experts=config.num_experts // self.tp_size,
294
+ hidden_size=config.hidden_size,
295
+ params_dtype=config.torch_dtype,
296
+ deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
297
+ async_finish=True, # TODO
298
+ return_recv_hook=True,
299
+ )
300
+
301
+ def forward(
302
+ self,
303
+ hidden_states: torch.Tensor,
304
+ forward_batch: Optional[ForwardBatch] = None,
305
+ use_reduce_scatter: bool = False,
306
+ ) -> torch.Tensor:
307
+ if not get_moe_a2a_backend().is_deepep():
308
+ return self.forward_normal(hidden_states, use_reduce_scatter)
198
309
  else:
199
- self.shared_experts = None
310
+ return self.forward_deepep(hidden_states, forward_batch)
200
311
 
201
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
202
- orig_shape = hidden_states.shape
203
- hidden_states_flat = hidden_states.view(-1, self.hidden_size)
312
+ def get_moe_weights(self):
313
+ return [
314
+ x.data
315
+ for name, x in self.experts.named_parameters()
316
+ if name not in ["correction_bias"]
317
+ ]
204
318
 
319
+ def _forward_shared_experts(self, hidden_states: torch.Tensor):
205
320
  shared_output = None
206
- if self.shared_experts is not None:
207
- shared_output = self.shared_experts(hidden_states_flat)
321
+ if self.num_shared_experts > 0:
322
+ shared_output = self.shared_experts(hidden_states)
323
+ return shared_output
208
324
 
209
- router_logits, _ = self.gate(hidden_states_flat)
210
- topk_output = self.topk(hidden_states_flat, router_logits)
211
- final_hidden_states = self.experts(hidden_states_flat, topk_output)
325
+ def _forward_router_experts(self, hidden_states: torch.Tensor):
326
+ # router_logits: (num_tokens, n_experts)
327
+ router_logits = self.gate(hidden_states)
328
+ topk_output = self.topk(hidden_states, router_logits)
329
+ return self.experts(hidden_states, topk_output)
212
330
 
213
- if shared_output is not None:
331
+ def forward_normal_dual_stream(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ ) -> torch.Tensor:
335
+ current_stream = torch.cuda.current_stream()
336
+ self.alt_stream.wait_stream(current_stream)
337
+ shared_output = self._forward_shared_experts(hidden_states.clone())
338
+
339
+ with torch.cuda.stream(self.alt_stream):
340
+ router_output = self._forward_router_experts(hidden_states)
341
+ current_stream.wait_stream(self.alt_stream)
342
+
343
+ return router_output, shared_output
344
+
345
+ def forward_normal(
346
+ self,
347
+ hidden_states: torch.Tensor,
348
+ use_reduce_scatter: bool = False,
349
+ ) -> torch.Tensor:
350
+ num_tokens, hidden_size = hidden_states.shape
351
+ hidden_states = hidden_states.view(-1, hidden_size)
352
+
353
+ DUAL_STREAM_TOKEN_THRESHOLD = 1024
354
+ if (
355
+ self.alt_stream is not None
356
+ and hidden_states.shape[0] > 0
357
+ and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
358
+ and get_is_capture_mode()
359
+ ):
360
+ final_hidden_states, shared_output = self.forward_normal_dual_stream(
361
+ hidden_states
362
+ )
363
+ else:
364
+ shared_output = self._forward_shared_experts(hidden_states)
365
+ final_hidden_states = self._forward_router_experts(hidden_states)
366
+
367
+ if self.num_shared_experts > 0:
214
368
  final_hidden_states = final_hidden_states + shared_output
215
369
 
216
- if self.tp_size > 1:
370
+ if self.tp_size > 1 and not use_reduce_scatter:
217
371
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
372
+ return final_hidden_states.view(num_tokens, hidden_size)
373
+
374
+ def forward_deepep(
375
+ self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
376
+ ) -> torch.Tensor:
377
+ shared_output = None
378
+ forward_mode = forward_batch.forward_mode
379
+ if is_non_idle_and_non_empty(forward_mode, hidden_states):
380
+ router_logits = self.gate(hidden_states)
381
+ if self.num_shared_experts > 0:
382
+ shared_output = self.shared_experts(hidden_states)
383
+
384
+ topk_weights, topk_idx, _ = self.topk(
385
+ hidden_states,
386
+ router_logits,
387
+ num_token_non_padded=forward_batch.num_token_non_padded,
388
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
389
+ layer_id=self.layer_id,
390
+ ),
391
+ )
392
+ else:
393
+ topk_idx = torch.full(
394
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
395
+ )
396
+ topk_weights = torch.empty(
397
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
398
+ )
399
+
400
+ if self.ep_size > 1:
401
+ (
402
+ hidden_states,
403
+ topk_idx,
404
+ topk_weights,
405
+ reorder_topk_ids,
406
+ num_recv_tokens_per_expert,
407
+ seg_indptr,
408
+ masked_m,
409
+ expected_m,
410
+ ) = self.deepep_dispatcher.dispatch(
411
+ hidden_states,
412
+ topk_idx,
413
+ topk_weights,
414
+ forward_batch=forward_batch,
415
+ )
416
+
417
+ final_hidden_states = self.experts(
418
+ hidden_states=hidden_states,
419
+ topk_idx=topk_idx,
420
+ topk_weights=topk_weights,
421
+ reorder_topk_ids=reorder_topk_ids,
422
+ seg_indptr=seg_indptr,
423
+ masked_m=masked_m,
424
+ expected_m=expected_m,
425
+ num_recv_tokens_per_expert=num_recv_tokens_per_expert,
426
+ forward_batch=forward_batch,
427
+ )
428
+ if self.ep_size > 1:
429
+ final_hidden_states = self.deepep_dispatcher.combine(
430
+ final_hidden_states,
431
+ topk_idx,
432
+ topk_weights,
433
+ forward_batch=forward_batch,
434
+ )
218
435
 
219
- return final_hidden_states.view(orig_shape)
436
+ final_hidden_states *= self.routed_scaling_factor
220
437
 
438
+ if shared_output is not None:
439
+ final_hidden_states = final_hidden_states + shared_output
440
+ return final_hidden_states
221
441
 
222
- class BailingMoeBlock(nn.Module):
223
442
 
443
+ class BailingMoEAttention(nn.Module):
224
444
  def __init__(
225
445
  self,
226
446
  config: PretrainedConfig,
227
- layer_id: int,
447
+ layer_id: int = 0,
228
448
  quant_config: Optional[QuantizationConfig] = None,
449
+ reduce_results: bool = True,
229
450
  prefix: str = "",
451
+ alt_stream: Optional[torch.cuda.Stream] = None,
230
452
  ):
231
453
  super().__init__()
232
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
233
- self.attention = BailingAttention(
234
- config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
454
+ self.hidden_size = config.hidden_size
455
+ self.total_num_heads = config.num_attention_heads
456
+ self.total_kv_heads = config.num_key_value_heads
457
+ self.dp_size = get_attention_dp_size()
458
+ attn_tp_rank = get_attention_tp_rank()
459
+ attn_tp_size = get_attention_tp_size()
460
+
461
+ assert self.total_num_heads % attn_tp_size == 0
462
+ assert self.total_kv_heads % attn_tp_size == 0
463
+ assert self.total_num_heads >= self.total_kv_heads
464
+
465
+ self.num_heads = self.total_num_heads // attn_tp_size
466
+ self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
467
+ self.q_size = self.head_dim * self.num_heads
468
+
469
+ self.num_kv_heads = self.total_kv_heads // attn_tp_size
470
+ self.kv_size = max(1, self.num_kv_heads * self.head_dim)
471
+
472
+ self.scale = self.head_dim**-0.5
473
+
474
+ self.use_qk_norm = getattr(config, "use_qk_norm", False)
475
+
476
+ self.query_key_value = QKVParallelLinear(
477
+ self.hidden_size,
478
+ self.head_dim,
479
+ self.total_num_heads,
480
+ self.total_kv_heads,
481
+ bias=(config.use_bias or config.use_qkv_bias),
482
+ quant_config=quant_config,
483
+ prefix=add_prefix("query_key_value", prefix),
484
+ tp_rank=attn_tp_rank,
485
+ tp_size=attn_tp_size,
486
+ )
487
+
488
+ if self.use_qk_norm:
489
+ self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
490
+ self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
491
+
492
+ self.dense = RowParallelLinear(
493
+ self.total_num_heads * self.head_dim,
494
+ self.hidden_size,
495
+ bias=config.use_bias,
496
+ quant_config=quant_config,
497
+ reduce_results=reduce_results,
498
+ prefix=add_prefix("dense", prefix),
499
+ tp_rank=attn_tp_rank,
500
+ tp_size=attn_tp_size,
235
501
  )
236
- self.post_attention_layernorm = RMSNorm(
237
- config.hidden_size, eps=config.rms_norm_eps
502
+
503
+ if hasattr(config, "partial_rotary_factor"):
504
+ self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
505
+ elif hasattr(config, "rotary_dim"):
506
+ self.rotary_dim = config.rotary_dim
507
+ else:
508
+ self.rotary_dim = self.head_dim
509
+ self.rotary_emb = get_rope(
510
+ self.head_dim,
511
+ rotary_dim=self.rotary_dim,
512
+ max_position=config.max_position_embeddings,
513
+ base=config.rope_theta,
514
+ rope_scaling=config.rope_scaling,
238
515
  )
239
- self.mlp = BailingMoE(
240
- config=config,
516
+
517
+ self.attn = RadixAttention(
518
+ self.num_heads,
519
+ self.head_dim,
520
+ self.scale,
521
+ num_kv_heads=self.num_kv_heads,
241
522
  layer_id=layer_id,
242
- quant_config=quant_config,
243
- prefix=add_prefix("mlp", prefix),
523
+ prefix=add_prefix("attn", prefix),
244
524
  )
245
525
 
526
+ self.alt_stream = alt_stream
527
+
528
+ def _apply_qk_norm(
529
+ self, q: torch.Tensor, k: torch.Tensor
530
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
531
+ # overlap qk norm
532
+ if self.alt_stream is not None and get_is_capture_mode():
533
+ current_stream = torch.cuda.current_stream()
534
+ self.alt_stream.wait_stream(current_stream)
535
+ q_by_head = q.reshape(-1, self.head_dim)
536
+ q_by_head = self.query_layernorm(q_by_head)
537
+ with torch.cuda.stream(self.alt_stream):
538
+ k_by_head = k.reshape(-1, self.head_dim)
539
+ k_by_head = self.key_layernorm(k_by_head)
540
+ current_stream.wait_stream(self.alt_stream)
541
+ else:
542
+ q_by_head = q.reshape(-1, self.head_dim)
543
+ q_by_head = self.query_layernorm(q_by_head)
544
+ k_by_head = k.reshape(-1, self.head_dim)
545
+ k_by_head = self.key_layernorm(k_by_head)
546
+ q = q_by_head.view(q.shape)
547
+ k = k_by_head.view(k.shape)
548
+ return q, k
549
+
246
550
  def forward(
247
551
  self,
552
+ positions: torch.Tensor,
248
553
  hidden_states: torch.Tensor,
249
- position_ids: torch.Tensor,
250
- residual: Optional[torch.Tensor],
251
554
  forward_batch: ForwardBatch,
252
- ) -> Tuple[torch.Tensor, torch.Tensor]:
253
- # Pre-normalization and residual connection for the attention block
254
- if residual is None:
255
- residual = hidden_states
256
- normed_hidden_states = self.input_layernorm(hidden_states)
555
+ ) -> torch.Tensor:
556
+ if hidden_states.shape[0] == 0:
557
+ return hidden_states
558
+ qkv, _ = self.query_key_value(hidden_states)
559
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
560
+ if self.use_qk_norm:
561
+ q, k = self._apply_qk_norm(q, k)
562
+ q, k = self.rotary_emb(
563
+ positions,
564
+ q,
565
+ k,
566
+ fused_set_kv_buffer_arg=(
567
+ create_fused_set_kv_buffer_arg(
568
+ value=v,
569
+ layer=self.attn,
570
+ forward_batch=forward_batch,
571
+ )
572
+ if enable_fused_set_kv_buffer(forward_batch)
573
+ else None
574
+ ),
575
+ )
576
+ context_layer = self.attn(
577
+ q,
578
+ k,
579
+ v,
580
+ forward_batch,
581
+ save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
582
+ )
583
+ attn_output, _ = self.dense(context_layer)
584
+ return attn_output
585
+
586
+
587
+ class BailingMoEBlock(nn.Module):
588
+ def __init__(
589
+ self,
590
+ config: PretrainedConfig,
591
+ layer_id: int = 0,
592
+ quant_config: Optional[QuantizationConfig] = None,
593
+ prefix: str = "",
594
+ alt_stream: Optional[torch.cuda.Stream] = None,
595
+ ):
596
+ super().__init__()
597
+ hidden_size = config.hidden_size
598
+
599
+ self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
600
+ self.dp_size = get_attention_dp_size()
601
+ self.attention = BailingMoEAttention(
602
+ config,
603
+ layer_id,
604
+ quant_config,
605
+ reduce_results=False,
606
+ prefix=add_prefix("attention", prefix),
607
+ alt_stream=alt_stream,
608
+ )
609
+ self.layer_id = layer_id
610
+ self.attn_tp_size = get_attention_tp_size()
611
+ self.attn_tp_rank = get_attention_tp_rank()
612
+
613
+ self.is_layer_sparse = self._is_layer_sparse(
614
+ config, layer_id=layer_id, is_nextn=False
615
+ )
616
+ is_previous_layer_sparse = self._is_layer_sparse(
617
+ config, layer_id=layer_id - 1, is_nextn=False
618
+ )
619
+
620
+ self.layer_scatter_modes = LayerScatterModes.init_new(
621
+ layer_id=layer_id,
622
+ num_layers=config.num_hidden_layers,
623
+ is_layer_sparse=self.is_layer_sparse,
624
+ is_previous_layer_sparse=is_previous_layer_sparse,
625
+ )
626
+
627
+ self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
628
+
629
+ if self.is_layer_sparse:
630
+ self.mlp = BailingMoESparseMoeBlock(
631
+ layer_id=layer_id,
632
+ config=config,
633
+ quant_config=quant_config,
634
+ alt_stream=alt_stream,
635
+ prefix=add_prefix("mlp", prefix),
636
+ )
257
637
  else:
258
- normed_hidden_states, residual = self.input_layernorm(
259
- hidden_states, residual
638
+ if enable_moe_dense_fully_dp():
639
+ mlp_tp_rank, mlp_tp_size = 0, 1
640
+ else:
641
+ mlp_tp_rank, mlp_tp_size = None, None
642
+ self.mlp = BailingMoEMLP(
643
+ intermediate_size=config.intermediate_size,
644
+ config=config,
645
+ quant_config=quant_config,
646
+ prefix=add_prefix("mlp", prefix),
647
+ tp_rank=mlp_tp_rank,
648
+ tp_size=mlp_tp_size,
260
649
  )
261
650
 
262
- attn_output = self.attention(
263
- hidden_states=normed_hidden_states,
264
- position_ids=position_ids,
651
+ self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
652
+
653
+ self.layer_communicator = LayerCommunicator(
654
+ layer_scatter_modes=self.layer_scatter_modes,
655
+ input_layernorm=self.input_layernorm,
656
+ post_attention_layernorm=self.post_attention_layernorm,
657
+ allow_reduce_scatter=True,
658
+ )
659
+
660
+ def _is_layer_sparse(
661
+ self, config: PretrainedConfig, layer_id: int, is_nextn: bool
662
+ ) -> bool:
663
+ return is_nextn or (
664
+ config.num_experts is not None and layer_id >= config.first_k_dense_replace
665
+ )
666
+
667
+ def forward(
668
+ self,
669
+ positions: torch.Tensor,
670
+ hidden_states: torch.Tensor,
671
+ forward_batch: ForwardBatch,
672
+ residual: Optional[torch.Tensor],
673
+ ) -> torch.Tensor:
674
+ hidden_states, residual = self.layer_communicator.prepare_attn(
675
+ hidden_states=hidden_states,
676
+ residual=residual,
677
+ forward_batch=forward_batch,
678
+ )
679
+
680
+ hidden_states = self.attention(
681
+ positions=positions,
682
+ hidden_states=hidden_states,
265
683
  forward_batch=forward_batch,
266
684
  )
267
685
 
268
- # Pre-normalization and residual connection for the MLP block
269
- normed_hidden_states, residual = self.post_attention_layernorm(
270
- attn_output, residual
686
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
687
+ hidden_states=hidden_states,
688
+ residual=residual,
689
+ forward_batch=forward_batch,
690
+ )
691
+
692
+ # For DP with padding, reduce scatter can be used instead of all-reduce.
693
+ use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
694
+ forward_batch
695
+ )
696
+
697
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
698
+
699
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
700
+ hidden_states=hidden_states,
701
+ residual=residual,
702
+ forward_batch=forward_batch,
271
703
  )
272
- mlp_output = self.mlp(normed_hidden_states)
273
704
 
274
- return mlp_output, residual
705
+ return hidden_states, residual
275
706
 
276
707
 
277
- class BailingMoeModel(nn.Module):
708
+ class BailingMoEModel(nn.Module):
278
709
 
279
710
  def __init__(
280
711
  self,
281
712
  config: PretrainedConfig,
282
713
  quant_config: Optional[QuantizationConfig] = None,
714
+ alt_stream: Optional[torch.cuda.Stream] = None,
283
715
  prefix: str = "",
284
716
  ):
285
717
  super().__init__()
718
+ self.pp_group = get_pp_group()
286
719
  self.config = config
287
- self.padding_idx = config.pad_token_id
288
720
  self.vocab_size = config.vocab_size
289
721
  self.embed_dim = config.hidden_size
722
+ if self.pp_group.is_first_rank:
723
+ self.word_embeddings = VocabParallelEmbedding(
724
+ self.vocab_size,
725
+ self.embed_dim,
726
+ quant_config=quant_config,
727
+ prefix=add_prefix("word_embeddings", prefix),
728
+ enable_tp=not is_dp_attention_enabled(),
729
+ )
730
+ else:
731
+ self.word_embeddings = PPMissingLayer()
290
732
 
291
- self.embed_tokens = VocabParallelEmbedding(
292
- config.vocab_size,
293
- config.hidden_size,
294
- prefix=add_prefix("embed_tokens", prefix),
295
- )
296
733
  self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
297
734
 
298
- self.layers = make_layers(
735
+ self.layers, self.start_layer, self.end_layer = make_layers(
299
736
  config.num_hidden_layers,
300
- lambda idx, prefix: BailingMoeBlock(
301
- config=config,
737
+ lambda idx, prefix: BailingMoEBlock(
302
738
  layer_id=idx,
739
+ config=config,
303
740
  quant_config=quant_config,
304
741
  prefix=prefix,
742
+ alt_stream=alt_stream,
305
743
  ),
744
+ pp_rank=self.pp_group.rank_in_group,
745
+ pp_size=self.pp_group.world_size,
306
746
  prefix=add_prefix("layers", prefix),
307
747
  )
308
-
309
- self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
748
+ if self.pp_group.is_last_rank:
749
+ self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
750
+ else:
751
+ self.norm = PPMissingLayer(return_tuple=True)
310
752
 
      def forward(
          self,
          input_ids: torch.Tensor,
-         position_ids: torch.Tensor,
+         positions: torch.Tensor,
          forward_batch: ForwardBatch,
-         input_embeds: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         if input_embeds is None:
-             hidden_states = self.embed_tokens(input_ids)
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[torch.Tensor, PPProxyTensors]:
+         if self.pp_group.is_first_rank:
+             if input_embeds is None:
+                 hidden_states = self.word_embeddings(input_ids)
+             else:
+                 hidden_states = input_embeds
+             residual = None
          else:
-             hidden_states = input_embeds
+             assert pp_proxy_tensors is not None
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
 
-         residual = None
-         for layer in self.layers:
-             hidden_states, residual = layer(
-                 hidden_states,
-                 position_ids,
-                 residual,
-                 forward_batch,
+         for i in range(self.start_layer, self.end_layer):
+             with get_global_expert_distribution_recorder().with_current_layer(i):
+                 layer = self.layers[i]
+                 hidden_states, residual = layer(
+                     positions,
+                     hidden_states,
+                     forward_batch,
+                     residual,
+                 )
+         if not self.pp_group.is_last_rank:
+             return PPProxyTensors(
+                 {
+                     "hidden_states": hidden_states,
+                     "residual": residual,
+                 }
              )
+         else:
+             if not forward_batch.forward_mode.is_idle():
+                 if residual is None:
+                     hidden_states = self.norm(hidden_states)
+                 else:
+                     hidden_states, _ = self.norm(hidden_states, residual)
+             return hidden_states
 
-         hidden_states, _ = self.norm(hidden_states, residual)
-         return hidden_states
-
-
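The model's forward is likewise rank-aware: the first rank starts from token embeddings, later ranks require a PPProxyTensors carrying hidden_states and residual from the previous stage, and every rank except the last returns such a proxy instead of normalized hidden states. A rough sketch of that handoff, reusing ToyPipelineStage from the sketch above; the plain dict stands in for PPProxyTensors and the residual bookkeeping is deliberately simplified.

import torch


def toy_stage_forward(stage, input_ids=None, proxy=None):
    """One pipeline hop: consume the previous stage's tensors, emit the next."""
    if stage.is_first:
        hidden_states = stage.embed(input_ids)
        residual = torch.zeros_like(hidden_states)
    else:
        assert proxy is not None, "non-first ranks need tensors from the previous stage"
        hidden_states = proxy["hidden_states"]
        residual = proxy["residual"]

    for layer in stage.layers:
        residual = residual + hidden_states  # keep a running residual stream
        hidden_states = layer(hidden_states)

    if not stage.is_last:
        # Intermediate ranks hand both tensors to the next stage.
        return {"hidden_states": hidden_states, "residual": residual}
    # Last rank folds the residual back in and normalizes.
    return stage.norm(hidden_states + residual)


stage = ToyPipelineStage(vocab_size=100, hidden_size=16, num_layers=4, pp_rank=0, pp_size=1)
final_hidden = toy_stage_forward(stage, input_ids=torch.randint(0, 100, (1, 8)))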
- class BailingMoeForCausalLM(nn.Module):
 
+ class BailingMoEForCausalLM(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
-     ) -> None:
+         prefix: str = "",
+     ):
          super().__init__()
+         self.pp_group = get_pp_group()
          self.config = config
-         self.model = BailingMoeModel(config=config, quant_config=quant_config)
-         self.lm_head = ParallelLMHead(
-             num_embeddings=config.vocab_size,
-             embedding_dim=config.hidden_size,
-             quant_config=quant_config,
+         self.quant_config = quant_config
+         alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+         self.model = BailingMoEModel(
+             config,
+             quant_config,
+             alt_stream=alt_stream,
+             prefix=add_prefix("model", ""),
          )
-         if config.tie_word_embeddings:
-             self.lm_head.weight = self.model.embed_tokens.weight
 
+         # If tie_word_embeddings is true, reuse the word embeddings as the LM head; otherwise the LM head is independent.
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.word_embeddings
+         else:
+             # TODO something wrong with ParallelLMHead with DP attention enabled
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+                 use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             )
          self.logits_processor = LogitsProcessor(config)
 
+     @property
+     def start_layer(self):
+         return self.model.start_layer
+
+     @property
+     def end_layer(self):
+         return self.model.end_layer
+
+     def get_embed_and_head(self):
+         """Used by the eagle_worker."""
+         return self.model.word_embeddings.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         """Used by the eagle_worker."""
+         del self.model.word_embeddings.weight
+         del self.lm_head.weight
+         self.model.word_embeddings.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
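With config.tie_word_embeddings set, the LM head is now the embedding module itself rather than a separate ParallelLMHead whose weight is overwritten afterwards. A compact illustration of what that tying means, using nn.Embedding and nn.Linear as stand-ins for the vocab-parallel variants; ToyLM and its sizes are invented for the example.

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16, tie_word_embeddings=True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # Tie the output projection to the input embedding: one tensor, two uses.
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids):
        hidden = self.embed(input_ids)   # [batch, seq, hidden]
        return self.lm_head(hidden)      # [batch, seq, vocab]


model = ToyLM()
logits = model(torch.randint(0, 100, (1, 4)))
# With tying enabled, the two modules literally share storage.
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()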
+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
-         inputs_embeds: Optional[torch.Tensor] = None,
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
      ) -> torch.Tensor:
-         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
-         return self.logits_processor(
-             input_ids, hidden_states, self.lm_head, forward_batch
+         hidden_states = self.model(
+             input_ids,
+             positions,
+             forward_batch,
+             input_embeds,
+             pp_proxy_tensors=pp_proxy_tensors,
          )
+         if self.pp_group.is_last_rank:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head, forward_batch
+             )
+         else:
+             return hidden_states
 
-     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+         if is_nextn:
+             if hasattr(self.config, "num_nextn_predict_layers"):
+                 num_nextn_layers = self.config.num_nextn_predict_layers
+                 assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                 # compatible with old design
+                 nextn_layer_id = (
+                     0
+                     if self.config.num_hidden_layers == 1
+                     else self.config.num_hidden_layers
+                 )
+             else:
+                 raise ValueError("num_nextn_predict_layers is not in the config")
 
          stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
 
+         if is_nextn:
+             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+             nextn_spec_weight_names = [
+                 "final_layernorm",
+                 "eh_proj",
+                 "enorm",
+                 "hnorm",
+             ]
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
          expert_params_mapping = FusedMoE.make_expert_params_mapping(
              ckpt_gate_proj_name="gate_proj",
              ckpt_down_proj_name="down_proj",
@@ -381,39 +911,87 @@ class BailingMoeForCausalLM(nn.Module):
 
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
+             if (
+                 ("v_head" in name)
+                 or ("inv_freq" in name)
+                 or (self.config.tie_word_embeddings and "lm_head" in name)
+             ):
+                 continue
 
              if (
                  hasattr(self.config, "norm_head")
                  and self.config.norm_head
                  and "lm_head.weight" in name
              ):
+                 import torch.nn.functional as F
+
                  loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
 
-             if "model.word_embeddings.weight" == name:
-                 name = "model.embed_tokens.weight"
+             if is_nextn:
+                 if not name.startswith(nextn_layer_prefix):
+                     continue
+
+                 # Use shared head and embed weights from target model
+                 if "shared_head.head" in name or "embed_tokens" in name:
+                     continue
+
+                 is_decoder = True
+                 # For nextn specific weights
+                 for weight_name in nextn_spec_weight_names:
+                     if weight_name in name:
+                         name = name.replace(nextn_layer_prefix, "model")
+                         is_decoder = False
+                         break
+                 # For decoder layer weights
+                 if is_decoder:
+                     name = name.replace(nextn_layer_prefix, "model.decoder")
 
              for param_name, weight_name, shard_id in stacked_params_mapping:
-                 if weight_name in name and "mlp.experts" not in name:
-                     full_param_name = name.replace(weight_name, param_name)
-                     param = params_dict[full_param_name]
-                     param.weight_loader(param, loaded_weight, shard_id)
-                     break
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
              else:
-                 for p_name, w_name, e_id, s_id in expert_params_mapping:
-                     if w_name in name and "mlp.experts" in name:
-                         full_param_name = name.replace(w_name, p_name)
-                         param = params_dict[full_param_name]
-                         param.weight_loader(
-                             param,
-                             loaded_weight,
-                             full_param_name,
-                             shard_id=s_id,
-                             expert_id=e_id,
-                         )
-                         break
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     if name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
                  else:
+                     # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
                          continue
+                     if name not in params_dict:
+                         continue
 
                      param = params_dict[name]
                      weight_loader = getattr(
@@ -421,5 +999,30 @@ class BailingMoeForCausalLM(nn.Module):
                      )
                      weight_loader(param, loaded_weight)
 
+         if not is_nextn:
+             self.routed_experts_weights_of_layer = {
+                 layer_id: layer.mlp.get_moe_weights()
+                 for layer_id, layer in enumerate(self.model.layers)
+                 if not isinstance(layer, PPMissingLayer)
+                 and isinstance(layer.mlp, BailingMoESparseMoeBlock)
+             }
+
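Most of the rewritten load_weights is name translation: fused gate/up projections are resolved through stacked_params_mapping, expert weights are deliberately left for expert_params_mapping, and names missing from params_dict are skipped. The standalone sketch below replays only that translation step on invented checkpoint names; no parameters are actually loaded.

# Toy name translation mirroring stacked_params_mapping in the diff above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint_names = [
    "model.layers.0.attention.dense.weight",          # passes through unchanged
    "model.layers.0.mlp.gate_proj.weight",            # dense MLP -> fused gate_up_proj, shard 0
    "model.layers.0.mlp.up_proj.weight",              # dense MLP -> fused gate_up_proj, shard 1
    "model.layers.1.mlp.experts.0.gate_proj.weight",  # expert weight: left for expert mapping
]

for name in checkpoint_names:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        if "mlp.experts" in name:
            # Expert weights are claimed by expert_params_mapping instead,
            # so they must not be renamed here.
            continue
        print(f"{name} -> {name.replace(weight_name, param_name)} (shard {shard_id})")
        break
    else:
        print(f"{name} -> {name} (direct or expert mapping)")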
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         num_groups = getattr(config, "n_group", 0)
+         return ModelConfigForExpertLocation(
+             num_layers=config.num_hidden_layers,
+             num_logical_experts=config.num_experts,
+             num_groups=None if num_groups == 0 else num_groups,
+         )
+
+
+ class BailingMoeForCausalLM(BailingMoEForCausalLM):
+     pass
+
+
+ class BailingMoeV2ForCausalLM(BailingMoEForCausalLM):
+     pass
+
 
- EntryClass = BailingMoeForCausalLM
+ EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM]
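For the new is_nextn path in load_weights above, the draft (next-token-prediction) layer's weights are re-rooted before loading: nextn-specific tensors drop the layer prefix, ordinary decoder tensors move under model.decoder, and everything outside that layer is skipped. A small self-contained replay of that renaming on invented checkpoint names; the layer index 20 is an arbitrary example of num_hidden_layers.

nextn_layer_prefix = "model.layers.20"  # example: num_hidden_layers == 20
nextn_spec_weight_names = ["final_layernorm", "eh_proj", "enorm", "hnorm"]

checkpoint_names = [
    "model.layers.20.enorm.weight",                        # nextn-specific -> model.enorm.weight
    "model.layers.20.eh_proj.weight",                      # nextn-specific -> model.eh_proj.weight
    "model.layers.20.attention.query_key_value.weight",    # decoder weight -> model.decoder...
    "model.layers.3.mlp.gate_proj.weight",                 # not the nextn layer -> skipped
]

for name in checkpoint_names:
    if not name.startswith(nextn_layer_prefix):
        print(f"{name}: skipped (belongs to the target model)")
        continue
    is_decoder = True
    for weight_name in nextn_spec_weight_names:
        if weight_name in name:
            name = name.replace(nextn_layer_prefix, "model")
            is_decoder = False
            break
    if is_decoder:
        name = name.replace(nextn_layer_prefix, "model.decoder")
    print(f"-> {name}")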