sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next.py (new file)
@@ -0,0 +1,1039 @@
+ import enum
+ import logging
+ from typing import Any, Dict, Iterable, Optional, Set, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.configs.qwen3_next import Qwen3NextConfig
+ from sglang.srt.distributed import (
+     divide,
+     get_pp_group,
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
+ from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     is_dp_attention_enabled,
+ )
+ from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     sharded_weight_loader,
+ )
+ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
+ from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs
+
+ logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()
+
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def fused_qkvzba_split_reshape_cat_kernel(
+     mixed_qkv,
+     z,
+     b,
+     a,
+     mixed_qkvz,
+     mixed_ba,
+     NUM_HEADS_QK: tl.constexpr,
+     NUM_HEADS_V: tl.constexpr,
+     HEAD_QK: tl.constexpr,
+     HEAD_V: tl.constexpr,
+ ):
+     i_bs, i_qk = tl.program_id(0), tl.program_id(1)
+     QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
+     BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
+     QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     q_end: tl.constexpr = HEAD_QK
+     blk_q_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(0, q_end)
+     )
+     k_end: tl.constexpr = q_end + HEAD_QK
+     blk_k_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(q_end, k_end)
+     )
+     v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     blk_v_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(k_end, v_end)
+     )
+     z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+     blk_z_ptr = (
+         mixed_qkvz
+         + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
+         + i_qk * QKVZ_DIM_T
+         + tl.arange(v_end, z_end)
+     )
+     blk_q_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + i_qk * HEAD_QK
+         + tl.arange(0, HEAD_QK)
+     )
+     blk_k_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + NUM_HEADS_QK * HEAD_QK
+         + i_qk * HEAD_QK
+         + tl.arange(0, HEAD_QK)
+     )
+     blk_v_st_ptr = (
+         mixed_qkv
+         + i_bs * NUM_HEADS_QK * QKV_DIM_T
+         + NUM_HEADS_QK * HEAD_QK * 2
+         + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+         + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+     )
+     blk_z_st_ptr = (
+         z
+         + i_bs * NUM_HEADS_V * HEAD_V
+         + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+         + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+     )
+     tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
+     tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
+     tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
+     tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
+     b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
+     a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
+     for i in tl.static_range(b_end):
+         blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+         blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
+         tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
+     for i in tl.static_range(b_end, a_end):
+         blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+         blk_a_st_ptr = (
+             a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
+         )
+         tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
+
+
+ def fused_qkvzba_split_reshape_cat(
+     mixed_qkvz,
+     mixed_ba,
+     num_heads_qk,
+     num_heads_v,
+     head_qk,
+     head_v,
+ ):
+     batch, seq_len = mixed_qkvz.shape[0], 1
+     qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
+     mixed_qkv = torch.empty(
+         [batch * seq_len, qkv_dim_t],
+         dtype=mixed_qkvz.dtype,
+         device=mixed_qkvz.device,
+     )
+     z = torch.empty(
+         [batch * seq_len, num_heads_v, head_v],
+         dtype=mixed_qkvz.dtype,
+         device=mixed_qkvz.device,
+     )
+     b = torch.empty(
+         [batch * seq_len, num_heads_v],
+         dtype=mixed_ba.dtype,
+         device=mixed_ba.device,
+     )
+     a = torch.empty_like(b)
+     grid = (batch * seq_len, num_heads_qk)
+     fused_qkvzba_split_reshape_cat_kernel[grid](
+         mixed_qkv,
+         z,
+         b,
+         a,
+         mixed_qkvz,
+         mixed_ba,
+         num_heads_qk,
+         num_heads_v,
+         head_qk,
+         head_v,
+         num_warps=1,
+         num_stages=3,
+     )
+     return mixed_qkv, z, b, a
+
+
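For reference, when the fused Triton kernel above is not taken, Qwen3GatedDeltaNet.forward (later in this file) falls back to fix_query_key_value_ordering followed by a reshape and concat. The pure-PyTorch sketch below restates that unfused equivalent for the decode case (seq_len == 1), with head counts given per tensor-parallel rank; the function name and argument order are illustrative only and not part of the wheel.

import torch

def qkvzba_split_reference(mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk, head_v):
    # Unfused sketch of fused_qkvzba_split_reshape_cat (decode, seq_len == 1).
    # num_heads_qk / num_heads_v are the per-tensor-parallel-rank head counts.
    batch = mixed_qkvz.shape[0]
    r = num_heads_v // num_heads_qk  # value/z heads per query-key head group
    qkvz = mixed_qkvz.view(batch, num_heads_qk, 2 * head_qk + 2 * r * head_v)
    q, k, v, z = torch.split(qkvz, [head_qk, head_qk, r * head_v, r * head_v], dim=2)
    b, a = torch.split(mixed_ba.view(batch, num_heads_qk, 2 * r), [r, r], dim=2)
    mixed_qkv = torch.cat([t.reshape(batch, -1) for t in (q, k, v)], dim=-1)
    return (
        mixed_qkv,
        z.reshape(batch, num_heads_v, head_v),
        b.reshape(batch, num_heads_v),
        a.reshape(batch, num_heads_v),
    )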
+ # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+ @triton.jit
+ def fused_gdn_gating_kernel(
+     g,
+     A_log,
+     a,
+     dt_bias,
+     seq_len,
+     NUM_HEADS: tl.constexpr,
+     beta: tl.constexpr,
+     threshold: tl.constexpr,
+     BLK_HEADS: tl.constexpr,
+ ):
+     i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+     off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
+     mask = head_off < NUM_HEADS
+     blk_A_log = tl.load(A_log + head_off, mask=mask)
+     blk_a = tl.load(a + off, mask=mask)
+     blk_bias = tl.load(dt_bias + head_off, mask=mask)
+     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+     softplus_x = tl.where(
+         beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+     )
+     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+
+
+ def fused_gdn_gating(
+     A_log: torch.Tensor,
+     a: torch.Tensor,
+     dt_bias: torch.Tensor,
+     beta: float = 1.0,
+     threshold: float = 20.0,
+ ) -> torch.Tensor:
+     batch, num_heads = a.shape
+     seq_len = 1
+     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
+     g = torch.empty_like(a, dtype=torch.float32)
+     fused_gdn_gating_kernel[grid](
+         g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+     )
+     return g
+
+
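The kernel above computes exactly the gate spelled out in the comment that precedes it: g = -exp(A_log) * softplus(a + dt_bias), using the same beta/threshold linearization as torch.nn.functional.softplus. A plain PyTorch version, useful only as a sanity check and not shipped in the wheel, could look like this:

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, dt_bias, beta=1.0, threshold=20.0):
    # Reference for fused_gdn_gating: computed in float32, one value per head.
    x = a.float() + dt_bias.float()
    return -A_log.float().exp() * F.softplus(x, beta=beta, threshold=threshold)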
+ class Qwen3GatedDeltaNet(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.hidden_size = config.hidden_size
+         self.num_v_heads = config.linear_num_value_heads
+         self.num_k_heads = config.linear_num_key_heads
+         self.head_k_dim = config.linear_key_head_dim
+         self.head_v_dim = config.linear_value_head_dim
+         self.key_dim = self.head_k_dim * self.num_k_heads
+         self.value_dim = self.head_v_dim * self.num_v_heads
+         self.alt_stream = alt_stream
+
+         self.conv_kernel_size = config.linear_conv_kernel_dim
+         self.layer_id = layer_id
+         self.activation = config.hidden_act
+         self.layer_norm_epsilon = config.rms_norm_eps
+
+         # QKV
+         self.conv_dim = self.key_dim * 2 + self.value_dim
+         self.conv1d = ColumnParallelLinear(
+             input_size=self.conv_kernel_size,
+             output_size=self.conv_dim,
+             bias=False,
+             quant_config=None,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+         # projection of the input hidden states
+         projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+         projection_size_ba = self.num_v_heads * 2
+
+         self.in_proj_qkvz = ColumnParallelLinear(
+             input_size=self.hidden_size,
+             output_size=projection_size_qkvz,
+             bias=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+         self.in_proj_ba = ColumnParallelLinear(
+             input_size=self.hidden_size,
+             output_size=projection_size_ba,
+             bias=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         query_key_settings = (self.key_dim, 0, False)
+         value_settings = (self.value_dim, 0, False)
+
+         delattr(self.conv1d.weight, "weight_loader")
+         set_weight_attrs(
+             self.conv1d.weight,
+             {
+                 "weight_loader": mamba_v2_sharded_weight_loader(
+                     [
+                         query_key_settings,
+                         query_key_settings,
+                         value_settings,
+                     ],
+                     self.attn_tp_size,
+                     self.attn_tp_rank,
+                 )
+             },
+         )
+
+         # selective projection used to make dt, B and C input dependent
+
+         # time step projection (discretization)
+         # instantiate once and copy inv_dt in init_weights of PretrainedModel
+         self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads // self.attn_tp_size))
+
+         A = torch.empty(
+             divide(self.num_v_heads, self.attn_tp_size), dtype=torch.float32
+         ).uniform_(0, 16)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.A_log._no_weight_decay = True
+
+         set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+         set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+         self.norm = RMSNormGated(
+             self.head_v_dim,
+             eps=self.layer_norm_epsilon,
+             group_size=None,
+             norm_before_gate=True,
+             device=torch.cuda.current_device(),
+             dtype=config.torch_dtype,
+         )
+
+         self.out_proj = RowParallelLinear(
+             self.value_dim,
+             self.hidden_size,
+             bias=False,
+             input_is_parallel=True,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+     def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
+         """
+         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
+         """
+         new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
+             self.num_k_heads // self.attn_tp_size,
+             (
+                 self.head_k_dim
+                 + self.head_k_dim
+                 + (self.head_v_dim + self.head_v_dim)
+                 * self.num_v_heads
+                 // self.num_k_heads
+             ),
+         )
+         new_tensor_shape_ba = mixed_ba.size()[:-1] + (
+             self.num_k_heads // self.attn_tp_size,
+             2 * self.num_v_heads // self.num_k_heads,
+         )
+
+         mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
+         mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
+
+         split_arg_list_qkvz = [
+             self.head_k_dim,
+             self.head_k_dim,
+             (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+             (self.num_v_heads // self.num_k_heads * self.head_v_dim),
+         ]
+         split_arg_list_ba = [
+             self.num_v_heads // self.num_k_heads,
+             self.num_v_heads // self.num_k_heads,
+         ]
+
+         # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
+         # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
+         (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
+         (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
+
+         # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
+         value = value.reshape(value.size(0), -1, self.head_v_dim)
+         z = z.reshape(z.size(0), -1, self.head_v_dim)
+         b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
+         a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
+
+         return query, key, value, z, b, a
+
+     def _forward_input_proj(self, hidden_states: torch.Tensor):
+         DUAL_STREAM_TOKEN_THRESHOLD = 1024
+         seq_len, _ = hidden_states.shape
+         if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
+             current_stream = torch.cuda.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+             with torch.cuda.stream(self.alt_stream):
+                 projected_states_ba, _ = self.in_proj_ba(hidden_states)
+             current_stream.wait_stream(self.alt_stream)
+         else:
+             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+             projected_states_ba, _ = self.in_proj_ba(hidden_states)
+         return projected_states_qkvz, projected_states_ba
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         seq_len, _ = hidden_states.shape
+         is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()
+
+         projected_states_qkvz, projected_states_ba = self._forward_input_proj(
+             hidden_states
+         )
+
+         if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph:
+             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
+                 projected_states_qkvz,
+                 projected_states_ba,
+                 triton.cdiv(self.num_k_heads, self.attn_tp_size),
+                 triton.cdiv(self.num_v_heads, self.attn_tp_size),
+                 self.head_k_dim,
+                 self.head_v_dim,
+             )
+         else:
+             query, key, value, z, b, a = self.fix_query_key_value_ordering(
+                 projected_states_qkvz, projected_states_ba
+             )
+             query, key, value = map(
+                 lambda x: x.reshape(x.shape[0], -1), (query, key, value)
+             )
+             mixed_qkv = torch.cat((query, key, value), dim=-1)
+             # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l")
+
+         # 2. Convolution sequence transformation
+         conv_weights = self.conv1d.weight.view(
+             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+         )
+
+         kwargs = {
+             "mixed_qkv": mixed_qkv,
+             "conv_weights": conv_weights,
+             "bias": self.conv1d.bias,
+             "activation": self.activation,
+             "key_dim": self.key_dim,
+             "value_dim": self.value_dim,
+             "attention_tp_size": self.attn_tp_size,
+             "head_k_dim": self.head_k_dim,
+             "head_v_dim": self.head_v_dim,
+             "a": a,
+             "b": b,
+             "A_log": self.A_log,
+             "dt_bias": self.dt_bias,
+             "layer_id": self.layer_id,
+             "seq_len": seq_len,
+             "z": z,
+         }
+
+         core_attn_out = forward_batch.attn_backend.forward(
+             q=None,
+             k=None,
+             v=None,
+             layer=None,
+             forward_batch=forward_batch,
+             **kwargs,
+         )
+
+         z_shape_og = z.shape
+         # reshape input data into 2D tensor
+         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+         z = z.reshape(-1, z.shape[-1])
+         core_attn_out = self.norm(core_attn_out, z)
+         core_attn_out = core_attn_out.reshape(z_shape_og)
+         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
+
+         output, _ = self.out_proj(core_attn_out)
+         return output
+
+
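Both _forward_input_proj above and _apply_qk_norm later in this file use the same dual-stream trick: one projection (or norm) is enqueued on the current CUDA stream while the other runs on alt_stream, with wait_stream calls fencing the hand-off in both directions. A minimal standalone sketch of the pattern, with illustrative names and assuming a CUDA device and a pre-created side stream:

import torch

def run_overlapped(x, f, g, alt_stream):
    # f(x) is enqueued on the current stream while g(x) runs on alt_stream.
    main_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(main_stream)   # alt_stream must see x fully produced
    out_f = f(x)                          # work on the main stream
    with torch.cuda.stream(alt_stream):
        out_g = g(x)                      # concurrent work on the side stream
    main_stream.wait_stream(alt_stream)   # re-join before the results are consumed
    return out_f, out_g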
+ class Qwen3HybridLinearDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+
+         # Qwen3Next all layers are sparse and have no nextn now
+         self.is_layer_sparse = True
+         is_previous_layer_sparse = True
+         self.layer_id = layer_id
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         if self.is_layer_sparse:
+             self.mlp = Qwen2MoeSparseMoeBlock(
+                 layer_id=layer_id,
+                 config=config,
+                 quant_config=quant_config,
+                 alt_stream=alt_stream,
+             )
+         else:
+             self.mlp = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         **kwargs,
+     ):
+         forward_batch = kwargs.get("forward_batch", None)
+
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             hidden_states = self.linear_attn(
+                 hidden_states,
+                 forward_batch,
+             )
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ class Qwen3HybridAttentionDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % self.attn_tp_size == 0
+         self.num_heads = self.total_num_heads // self.attn_tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= self.attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+         self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+         self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.partial_rotary_factor = config.partial_rotary_factor
+         self.layer_id = layer_id
+
+         self.attn_output_gate = getattr(config, "attn_output_gate", True)
+         if self.attn_output_gate:
+             logger.warning_once("using attn output gate!")
+
+         self.rotary_emb = get_rope(
+             head_size=self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             rope_scaling=self.rope_scaling,
+             base=self.rope_theta,
+             partial_rotary_factor=self.partial_rotary_factor,
+             is_neox_style=True,
+             dtype=torch.get_default_dtype(),  # see impl of get_rope
+         )
+
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_dim,
+             self.total_num_heads * (1 + self.attn_output_gate),
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=False,
+             tp_rank=self.attn_tp_rank,
+             tp_size=self.attn_tp_size,
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=f"{prefix}.attn",
+         )
+
+         # Qwen3Next all layers are sparse and have no nextn now
+         self.is_layer_sparse = True
+         is_previous_layer_sparse = True
+
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         if self.is_layer_sparse:
+             self.mlp = Qwen2MoeSparseMoeBlock(
+                 layer_id=layer_id,
+                 config=config,
+                 quant_config=quant_config,
+                 alt_stream=alt_stream,
+             )
+         else:
+             self.mlp = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+         self.alt_stream = alt_stream
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # overlap qk norm
+         if self.alt_stream is not None and get_is_capture_mode():
+             current_stream = torch.cuda.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             with torch.cuda.stream(self.alt_stream):
+                 k_by_head = k.reshape(-1, self.head_dim)
+                 k_by_head = self.k_norm(k_by_head)
+             current_stream.wait_stream(self.alt_stream)
+         else:
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             k_by_head = k.reshape(-1, self.head_dim)
+             k_by_head = self.k_norm(k_by_head)
+         q = q_by_head.view(q.shape)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def self_attention(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+
+         if self.attn_output_gate:
+             q_gate, k, v = qkv.split(
+                 [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
+             )
+             orig_shape = q_gate.shape[:-1]
+             q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
+             q, gate = torch.chunk(q_gate, 2, dim=-1)
+             q = q.reshape(*orig_shape, -1)
+             gate = gate.reshape(*orig_shape, -1)
+         else:
+             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+         q, k = self._apply_qk_norm(q, k)
+
+         q, k = self.rotary_emb(positions, q, k)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+
+         if self.attn_output_gate:
+             gate = torch.sigmoid(gate)
+             attn_output = attn_output * gate
+
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+         **kwargs: Any,
+     ):
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             hidden_states = self.self_attention(
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 forward_batch=forward_batch,
+             )
+
+         # Fully Connected
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+             forward_batch
+         )
+         hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ ALL_DECODER_LAYER_TYPES = {
+     "attention": Qwen3HybridAttentionDecoderLayer,
+     "linear_attention": Qwen3HybridLinearDecoderLayer,
+ }
+
+
+ class Qwen3NextModel(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             enable_tp=not is_dp_attention_enabled(),
+         )
+
+         def get_layer(idx: int, prefix: str):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+             return layer_class(
+                 config,
+                 idx,
+                 quant_config=quant_config,
+                 prefix=prefix,
+                 alt_stream=alt_stream,
+             )
+
+         self.layers = make_layers(
+             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+         )
+
+         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.infer_count = 0
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         # mamba_cache_params: MambaCacheParams,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+
+         # pass a sequence index tensor, that is required for
+         # proper continuous batching computation including
+         # chunked prefill
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+         else:
+             hidden_states = self.embed_tokens(input_ids)
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 layer_id=i,
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 forward_batch=forward_batch,
+             )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is None:
+                 hidden_states = self.norm(hidden_states)
+             else:
+                 hidden_states, _ = self.norm(hidden_states, residual)
+
+         return hidden_states
+
+
+ class HybridLayerType(enum.Enum):
+     full_attention = "attention"
+     swa_attention = "swa_attention"
+     linear_attention = "linear_attention"
+     mamba2 = "mamba"
+
+
+ class Qwen3NextForCausalLM(nn.Module):
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: Qwen3NextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.pp_group = get_pp_group()
+         assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+         self.quant_config = quant_config
+         self.model = Qwen3NextModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             org_num_embeddings=config.vocab_size,
+             prefix=add_prefix("lm_head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.lm_head = self.lm_head.float()
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+     ) -> Set[str]:
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+
+             if is_mtp:
+
+                 if "mtp" not in name:
+                     continue
+
+                 if name in [
+                     "mtp.fc.weight",
+                     "mtp.pre_fc_norm_embedding.weight",
+                     "mtp.pre_fc_norm_hidden.weight",
+                 ]:
+                     name = name.replace("mtp.", "")
+                 else:
+                     name = name.replace("mtp", "model")
+
+             if not is_mtp and "mtp" in name:
+                 continue
+
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             if ".self_attn." in name:
+                 name = name.replace(".self_attn", "")
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+
+                 # TODO(fix mtp loading)
+                 if "mlp.experts" in name:
+                     continue
+
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip layers on other devices.
+                 # if is_pp_missing_parameter(name, self):
+                 # continue
+                 if name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader")
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     # Skip layers on other devices.
+                     # if is_pp_missing_parameter(name, self):
+                     # continue
+                     # Skip loading extra bias for GPTQ models.
+                     if (
+                         name.endswith(".bias") or name.endswith("_bias")
+                     ) and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+
+                     weight_loader = getattr(param, "weight_loader")
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     # if is_pp_missing_parameter(name, self):
+                     # continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+         return loaded_params
+
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         return ModelConfigForExpertLocation(
+             num_layers=config.num_hidden_layers,
+             num_logical_experts=config.num_experts,
+             num_groups=None,
+         )
+
+
+ EntryClass = Qwen3NextForCausalLM