sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
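
The single hunk shown below is the contents of the new file sglang/srt/models/qwen3_next.py (entry 205 in the list above, +1042 lines), which implements the Qwen3-Next hybrid model: full-attention decoder layers interleaved with gated-delta-net ("linear_attention") layers and Qwen2-MoE sparse MLP blocks. As a rough usage sketch only (not part of this diff; the checkpoint name and port are illustrative placeholders), the new architecture would be served through the usual sglang entrypoint:

    python -m sglang.launch_server --model-path Qwen/Qwen3-Next-80B-A3B-Instruct --port 30000
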
@@ -0,0 +1,1042 @@
+ import enum
+ import logging
+ from typing import Any, Dict, Iterable, Optional, Set, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.configs.qwen3_next import Qwen3NextConfig
+ from sglang.srt.distributed import (
+     divide,
+     get_pp_group,
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
+ from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     is_dp_attention_enabled,
+ )
+ from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     sharded_weight_loader,
+ )
+ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
+ from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
+
+ logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()
+ _is_npu = is_npu()
+
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def fused_qkvzba_split_reshape_cat_kernel(
+     mixed_qkv,
+     z,
+     b,
+     a,
+     mixed_qkvz,
+     mixed_ba,
+     NUM_HEADS_QK: tl.constexpr,
+     NUM_HEADS_V: tl.constexpr,
+     HEAD_QK: tl.constexpr,
+     HEAD_V: tl.constexpr,
+ ):
+     i_bs, i_qk = tl.program_id(0), tl.program_id(1)
+     QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
+     BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
+     QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     q_end: tl.constexpr = HEAD_QK
+     blk_q_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(0, q_end)
+     )
+     k_end: tl.constexpr = q_end + HEAD_QK
+     blk_k_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(q_end, k_end)
+     )
+     v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     blk_v_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(k_end, v_end)
+     )
+     z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     blk_z_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(v_end, z_end)
+     )
+     blk_q_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + i_qk * HEAD_QK
+         + tl.arange(0, HEAD_QK)
+     )
+     blk_k_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + NUM_HEADS_QK * HEAD_QK
+         + i_qk * HEAD_QK
+         + tl.arange(0, HEAD_QK)
+     )
+     blk_v_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + NUM_HEADS_QK * HEAD_QK * 2
+         + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+         + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+     )
+     blk_z_st_ptr = (
+         z
+         + i_bs * NUM_HEADS_V * HEAD_V
+         + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+         + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+     )
+     tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
+     tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
+     tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
+     tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
+     b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
+     a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
+     for i in tl.static_range(b_end):
+         blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+         blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
+         tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
+     for i in tl.static_range(b_end, a_end):
+         blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+         blk_a_st_ptr = (
+             a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
+         )
+         tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
+
+
+ def fused_qkvzba_split_reshape_cat(
+     mixed_qkvz,
+     mixed_ba,
+     num_heads_qk,
+     num_heads_v,
+     head_qk,
+     head_v,
+ ):
+     batch, seq_len = mixed_qkvz.shape[0], 1
+     qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
+     mixed_qkv = torch.empty(
+         [batch * seq_len, qkv_dim_t],
+         dtype=mixed_qkvz.dtype,
+         device=mixed_qkvz.device,
+     )
+     z = torch.empty(
+         [batch * seq_len, num_heads_v, head_v],
+         dtype=mixed_qkvz.dtype,
+         device=mixed_qkvz.device,
+     )
+     b = torch.empty(
+         [batch * seq_len, num_heads_v],
+         dtype=mixed_ba.dtype,
+         device=mixed_ba.device,
+     )
+     a = torch.empty_like(b)
+     grid = (batch * seq_len, num_heads_qk)
+     fused_qkvzba_split_reshape_cat_kernel[grid](
+         mixed_qkv,
+         z,
+         b,
+         a,
+         mixed_qkvz,
+         mixed_ba,
+         num_heads_qk,
+         num_heads_v,
+         head_qk,
+         head_v,
+         num_warps=1,
+         num_stages=3,
+     )
+     return mixed_qkv, z, b, a
+
+
+ # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+ @triton.jit
+ def fused_gdn_gating_kernel(
+     g,
+     A_log,
+     a,
+     dt_bias,
+     seq_len,
+     NUM_HEADS: tl.constexpr,
+     beta: tl.constexpr,
+     threshold: tl.constexpr,
+     BLK_HEADS: tl.constexpr,
+ ):
+     i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+     off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
+     mask = head_off < NUM_HEADS
+     blk_A_log = tl.load(A_log + head_off, mask=mask)
+     blk_a = tl.load(a + off, mask=mask)
+     blk_bias = tl.load(dt_bias + head_off, mask=mask)
+     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+     softplus_x = tl.where(
+         beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+     )
+     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+
+
+ def fused_gdn_gating(
+     A_log: torch.Tensor,
+     a: torch.Tensor,
+     dt_bias: torch.Tensor,
+     beta: float = 1.0,
+     threshold: float = 20.0,
+ ) -> torch.Tensor:
+     batch, num_heads = a.shape
+     seq_len = 1
+     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
+     g = torch.empty_like(a, dtype=torch.float32)
+     fused_gdn_gating_kernel[grid](
+         g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+     )
+     return g
+
+
+ class Qwen3GatedDeltaNet(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.hidden_size = config.hidden_size
+         self.num_v_heads = config.linear_num_value_heads
+         self.num_k_heads = config.linear_num_key_heads
+         self.head_k_dim = config.linear_key_head_dim
+         self.head_v_dim = config.linear_value_head_dim
+         self.key_dim = self.head_k_dim * self.num_k_heads
+         self.value_dim = self.head_v_dim * self.num_v_heads
+         self.alt_stream = alt_stream
+
+         self.conv_kernel_size = config.linear_conv_kernel_dim
+         self.layer_id = layer_id
+         self.activation = config.hidden_act
+         self.layer_norm_epsilon = config.rms_norm_eps
+
+         # QKV
+         self.conv_dim = self.key_dim * 2 + self.value_dim
+         self.conv1d = ColumnParallelLinear(
+             input_size=self.conv_kernel_size,
+             output_size=self.conv_dim,
+             bias=False,
+             quant_config=None,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+         # projection of the input hidden states
+         projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+         projection_size_ba = self.num_v_heads * 2
+
+         self.in_proj_qkvz = ColumnParallelLinear(
+             input_size=self.hidden_size,
+             output_size=projection_size_qkvz,
+             bias=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+         self.in_proj_ba = ColumnParallelLinear(
+             input_size=self.hidden_size,
+             output_size=projection_size_ba,
+             bias=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         query_key_settings = (self.key_dim, 0, False)
+         value_settings = (self.value_dim, 0, False)
+
+         delattr(self.conv1d.weight, "weight_loader")
+         set_weight_attrs(
+             self.conv1d.weight,
+             {
+                 "weight_loader": mamba_v2_sharded_weight_loader(
+                     [
+                         query_key_settings,
+                         query_key_settings,
+                         value_settings,
+                     ],
+                     self.attn_tp_size,
+                     self.attn_tp_rank,
+                 )
+             },
+         )
+
+         # selective projection used to make dt, B and C input dependent
+
+         # time step projection (discretization)
+         # instantiate once and copy inv_dt in init_weights of PretrainedModel
+         self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads // self.attn_tp_size))
+
+         A = torch.empty(
+             divide(self.num_v_heads, self.attn_tp_size), dtype=torch.float32
+         ).uniform_(0, 16)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.A_log._no_weight_decay = True
+
+         set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+         set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+         self.norm = RMSNormGated(
+             self.head_v_dim,
+             eps=self.layer_norm_epsilon,
+             group_size=None,
+             norm_before_gate=True,
+             device=torch.get_device_module().current_device(),
+             dtype=config.torch_dtype,
+         )
+
+         self.out_proj = RowParallelLinear(
+             self.value_dim,
+             self.hidden_size,
+             bias=False,
+             input_is_parallel=True,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+     def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
+         """
+         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
+         """
+         new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
+             self.num_k_heads // self.attn_tp_size,
+             (
+                 self.head_k_dim
+                 + self.head_k_dim
+                 + (self.head_v_dim + self.head_v_dim)
+                 * self.num_v_heads
+                 // self.num_k_heads
+             ),
+         )
+         new_tensor_shape_ba = mixed_ba.size()[:-1] + (
+             self.num_k_heads // self.attn_tp_size,
+             2 * self.num_v_heads // self.num_k_heads,
+         )
+
+         mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
+         mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
+
+         split_arg_list_qkvz = [
+             self.head_k_dim,
+             self.head_k_dim,
+             (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+             (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+         ]
+         split_arg_list_ba = [
+             self.num_v_heads // self.num_k_heads,
+             self.num_v_heads // self.num_k_heads,
+         ]
+
+         # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
+         # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
+         (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
+         (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
+
+         # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
+         value = value.reshape(value.size(0), -1, self.head_v_dim)
+         z = z.reshape(z.size(0), -1, self.head_v_dim)
+         b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
+         a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
+
+         return query, key, value, z, b, a
+
+     def _forward_input_proj(self, hidden_states: torch.Tensor):
+         DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
+         seq_len, _ = hidden_states.shape
+         if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
+             current_stream = torch.cuda.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+             with torch.cuda.stream(self.alt_stream):
+                 projected_states_ba, _ = self.in_proj_ba(hidden_states)
+             current_stream.wait_stream(self.alt_stream)
+         else:
+             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+             projected_states_ba, _ = self.in_proj_ba(hidden_states)
+         return projected_states_qkvz, projected_states_ba
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         seq_len, _ = hidden_states.shape
+         is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()
+
+         projected_states_qkvz, projected_states_ba = self._forward_input_proj(
+             hidden_states
+         )
+
+         if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph:
+             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
+                 projected_states_qkvz,
+                 projected_states_ba,
+                 triton.cdiv(self.num_k_heads, self.attn_tp_size),
+                 triton.cdiv(self.num_v_heads, self.attn_tp_size),
+                 self.head_k_dim,
+                 self.head_v_dim,
+             )
+         else:
+             query, key, value, z, b, a = self.fix_query_key_value_ordering(
+                 projected_states_qkvz, projected_states_ba
+             )
+             query, key, value = map(
+                 lambda x: x.reshape(x.shape[0], -1), (query, key, value)
+             )
+             mixed_qkv = torch.cat((query, key, value), dim=-1)
+         # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l")
+
+         # 2. Convolution sequence transformation
+         conv_weights = self.conv1d.weight.view(
+             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+         )
+
+         kwargs = {
+             "mixed_qkv": mixed_qkv,
+             "conv_weights": conv_weights,
+             "bias": self.conv1d.bias,
+             "activation": self.activation,
+             "key_dim": self.key_dim,
+             "value_dim": self.value_dim,
+             "attention_tp_size": self.attn_tp_size,
+             "head_k_dim": self.head_k_dim,
+             "head_v_dim": self.head_v_dim,
+             "a": a,
+             "b": b,
+             "A_log": self.A_log,
+             "dt_bias": self.dt_bias,
+             "layer_id": self.layer_id,
+             "seq_len": seq_len,
+             "num_k_heads": self.num_k_heads,
+             "num_v_heads": self.num_v_heads,
+             "z": z,
+         }
+
+         core_attn_out = forward_batch.attn_backend.forward(
+             q=None,
+             k=None,
+             v=None,
+             layer=None,
+             forward_batch=forward_batch,
+             **kwargs,
+         )
+
+         z_shape_og = z.shape
+         # reshape input data into 2D tensor
+         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+         z = z.reshape(-1, z.shape[-1])
+         core_attn_out = self.norm(core_attn_out, z)
+         core_attn_out = core_attn_out.reshape(z_shape_og)
+         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
+
+         output, _ = self.out_proj(core_attn_out)
+         return output
+
+
+ class Qwen3HybridLinearDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+
+         # Qwen3Next all layers are sparse and have no nextn now
+         self.is_layer_sparse = True
+         is_previous_layer_sparse = True
+         self.layer_id = layer_id
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         if self.is_layer_sparse:
+             self.mlp = Qwen2MoeSparseMoeBlock(
+                 layer_id=layer_id,
+                 config=config,
+                 quant_config=quant_config,
+                 alt_stream=alt_stream,
+             )
+         else:
+             self.mlp = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         **kwargs,
+     ):
+         forward_batch = kwargs.get("forward_batch", None)
+
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             hidden_states = self.linear_attn(
+                 hidden_states,
+                 forward_batch,
+             )
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ class Qwen3HybridAttentionDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % self.attn_tp_size == 0
+         self.num_heads = self.total_num_heads // self.attn_tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= self.attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+         self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+         self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.partial_rotary_factor = config.partial_rotary_factor
+         self.layer_id = layer_id
+
+         self.attn_output_gate = getattr(config, "attn_output_gate", True)
+         if self.attn_output_gate:
+             logger.warning_once("using attn output gate!")
+
+         self.rotary_emb = get_rope(
+             head_size=self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             rope_scaling=self.rope_scaling,
+             base=self.rope_theta,
+             partial_rotary_factor=self.partial_rotary_factor,
+             is_neox_style=True,
+             dtype=torch.get_default_dtype(),  # see impl of get_rope
+         )
+
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_dim,
+             self.total_num_heads * (1 + self.attn_output_gate),
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=f"{prefix}.attn",
+         )
+
+         # Qwen3Next all layers are sparse and have no nextn now
+         self.is_layer_sparse = True
+         is_previous_layer_sparse = True
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         if self.is_layer_sparse:
+             self.mlp = Qwen2MoeSparseMoeBlock(
+                 layer_id=layer_id,
+                 config=config,
+                 quant_config=quant_config,
+                 alt_stream=alt_stream,
+             )
+         else:
+             self.mlp = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+         self.alt_stream = alt_stream
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # overlap qk norm
+         if self.alt_stream is not None and get_is_capture_mode():
+             current_stream = torch.cuda.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             with torch.cuda.stream(self.alt_stream):
+                 k_by_head = k.reshape(-1, self.head_dim)
+                 k_by_head = self.k_norm(k_by_head)
+             current_stream.wait_stream(self.alt_stream)
+         else:
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             k_by_head = k.reshape(-1, self.head_dim)
+             k_by_head = self.k_norm(k_by_head)
+         q = q_by_head.view(q.shape)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def self_attention(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+
+         if self.attn_output_gate:
+             q_gate, k, v = qkv.split(
+                 [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
+             )
+             orig_shape = q_gate.shape[:-1]
+             q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
+             q, gate = torch.chunk(q_gate, 2, dim=-1)
+             q = q.reshape(*orig_shape, -1)
+             gate = gate.reshape(*orig_shape, -1)
+         else:
+             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+         q, k = self._apply_qk_norm(q, k)
+
+         q, k = self.rotary_emb(positions, q, k)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+
+         if self.attn_output_gate:
+             gate = torch.sigmoid(gate)
+             attn_output = attn_output * gate
+
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+         **kwargs: Any,
+     ):
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             hidden_states = self.self_attention(
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 forward_batch=forward_batch,
+             )
+
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ ALL_DECODER_LAYER_TYPES = {
+     "attention": Qwen3HybridAttentionDecoderLayer,
+     "linear_attention": Qwen3HybridLinearDecoderLayer,
+ }
+
+
+ class Qwen3NextModel(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             enable_tp=not is_dp_attention_enabled(),
+         )
+
+         def get_layer(idx: int, prefix: str):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+             return layer_class(
+                 config,
+                 idx,
+                 quant_config=quant_config,
+                 prefix=prefix,
+                 alt_stream=alt_stream,
+             )
+
+         self.layers = make_layers(
+             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+         )
+
+         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.infer_count = 0
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         # mamba_cache_params: MambaCacheParams,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+
+         # pass a sequence index tensor, that is required for
+         # proper continuous batching computation including
+         # chunked prefill
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+         else:
+             hidden_states = self.embed_tokens(input_ids)
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 layer_id=i,
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 forward_batch=forward_batch,
+             )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is None:
+                 hidden_states = self.norm(hidden_states)
+             else:
+                 hidden_states, _ = self.norm(hidden_states, residual)
+
+         return hidden_states
+
+
+ class HybridLayerType(enum.Enum):
+     full_attention = "attention"
+     swa_attention = "swa_attention"
+     linear_attention = "linear_attention"
+     mamba2 = "mamba"
+
+
+ class Qwen3NextForCausalLM(nn.Module):
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.pp_group = get_pp_group()
+         assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+         self.quant_config = quant_config
+         self.model = Qwen3NextModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             org_num_embeddings=config.vocab_size,
+             prefix=add_prefix("lm_head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.lm_head = self.lm_head.float()
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+     ) -> Set[str]:
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+
+             if is_mtp:
+
+                 if "mtp" not in name:
+                     continue
+
+                 if name in [
+                     "mtp.fc.weight",
+                     "mtp.pre_fc_norm_embedding.weight",
+                     "mtp.pre_fc_norm_hidden.weight",
+                 ]:
+                     name = name.replace("mtp.", "")
+                 else:
+                     name = name.replace("mtp", "model")
+
+             if not is_mtp and "mtp" in name:
+                 continue
+
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             if ".self_attn." in name:
+                 name = name.replace(".self_attn", "")
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+
+                 # TODO(fix mtp loading)
+                 if "mlp.experts" in name:
+                     continue
+
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip layers on other devices.
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+                 if name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader")
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     # Skip layers on other devices.
+                     # if is_pp_missing_parameter(name, self):
+                     #     continue
+                     # Skip loading extra bias for GPTQ models.
+                     if (
+                         name.endswith(".bias") or name.endswith("_bias")
+                     ) and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+
+                     weight_loader = getattr(param, "weight_loader")
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     # if is_pp_missing_parameter(name, self):
+                     #     continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+         return loaded_params
+
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         return ModelConfigForExpertLocation(
+             num_layers=config.num_hidden_layers,
+             num_logical_experts=config.num_experts,
+             num_groups=None,
+         )
+
+
+ EntryClass = Qwen3NextForCausalLM
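
For readers skimming the Triton kernels in the hunk above, the quantity that fused_gdn_gating writes into g is stated by the comment above its kernel. A minimal eager-mode PyTorch sketch of the same math (a reference written for this summary, not code from the package; shapes follow the wrapper, i.e. a is [batch, num_heads] while A_log and dt_bias are [num_heads]):

    import torch
    import torch.nn.functional as F

    def gdn_gating_reference(A_log, a, dt_bias, beta=1.0, threshold=20.0):
        # g = -exp(A_log) * softplus(a + dt_bias), computed in float32.
        # F.softplus switches to the linear branch above `threshold`, matching
        # the tl.where guard inside fused_gdn_gating_kernel.
        x = a.float() + dt_bias.float()
        return -A_log.float().exp() * F.softplus(x, beta=beta, threshold=threshold)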
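
Similarly, the non-fused path in Qwen3GatedDeltaNet.fix_query_key_value_ordering is a view followed by torch.split. A self-contained sketch with small illustrative sizes (2 key heads, 4 value heads, head dims of 8, tensor-parallel size 1; these are not the real Qwen3-Next dimensions) shows the per-key-head layout it assumes:

    import torch

    nk, nv, hk, hv, tokens = 2, 4, 8, 8, 3  # illustrative sizes, tp_size == 1
    mixed_qkvz = torch.randn(tokens, nk * (2 * hk + 2 * (nv // nk) * hv))
    mixed_ba = torch.randn(tokens, nk * 2 * (nv // nk))

    # Group by key head, then split each group into q, k, v, z (and b, a).
    qkvz = mixed_qkvz.view(tokens, nk, 2 * hk + 2 * (nv // nk) * hv)
    ba = mixed_ba.view(tokens, nk, 2 * (nv // nk))
    q, k, v, z = torch.split(qkvz, [hk, hk, (nv // nk) * hv, (nv // nk) * hv], dim=2)
    b, a = torch.split(ba, [nv // nk, nv // nk], dim=2)
    v = v.reshape(tokens, nv, hv)  # [tokens, num_v_heads, head_v_dim]
    z = z.reshape(tokens, nv, hv)
    b = b.reshape(tokens, nv)
    a = a.reshape(tokens, nv)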