sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/models/kimi_linear.py
@@ -0,0 +1,678 @@
+# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/models/kimi_linear.py
+
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+
+from sglang.srt.configs.kimi_linear import KimiLinearConfig
+from sglang.srt.distributed import (
+    divide,
+    get_pp_group,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+    sharded_weight_loader,
+)
+from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA as KimiMLAAttention
+from sglang.srt.models.llama import LlamaMLP as KimiMLP
+from sglang.srt.models.transformers import maybe_prefix
+from sglang.srt.utils import make_layers
+from sglang.srt.utils.common import BumpAllocator, add_prefix, set_weight_attrs
+
+
+class KimiMoE(nn.Module):
+    def __init__(
+        self,
+        config: KimiLinearConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        moe_intermediate_size = config.moe_intermediate_size
+        num_experts = config.num_experts
+        moe_renormalize = config.moe_renormalize
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.num_shared_experts = config.num_shared_experts
+        self.layer_idx = layer_idx
+
+        if config.hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {config.hidden_act}. "
+                "Only silu is supported for now."
+            )
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size,
+            num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
+
+        self.experts = get_moe_impl_class(quant_config)(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_token,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_idx,
+            quant_config=quant_config,
+            routed_scaling_factor=self.routed_scaling_factor,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.topk = TopK(
+            top_k=config.num_experts_per_token,
+            renormalize=moe_renormalize,
+            use_grouped_topk=True,
+            num_expert_group=config.num_expert_group,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
+        )
+
+        if self.num_shared_experts is not None:
+            intermediate_size = moe_intermediate_size * self.num_shared_experts
+            self.shared_experts = KimiMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+            )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_size)
+        if self.num_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+        router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states.view(num_tokens, hidden_size)
+
+
+class KimiDeltaAttention(nn.Module):
+    def __init__(
+        self,
+        layer_idx: int,
+        hidden_size: int,
+        config: KimiLinearConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        rms_norm_eps: float = 1e-5,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.hidden_size = hidden_size
+        self.config = config
+        self.head_dim = config.linear_attn_config["head_dim"]
+        self.num_heads = config.linear_attn_config["num_heads"]
+        self.layer_idx = layer_idx
+        self.prefix = prefix
+        assert self.num_heads % self.tp_size == 0
+        self.local_num_heads = divide(self.num_heads, self.tp_size)
+
+        projection_size = self.head_dim * self.num_heads
+        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
+
+        self.q_proj = ColumnParallelLinear(
+            self.hidden_size,
+            projection_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.k_proj = ColumnParallelLinear(
+            self.hidden_size,
+            projection_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.k_proj",
+        )
+        self.v_proj = ColumnParallelLinear(
+            self.hidden_size,
+            projection_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.v_proj",
+        )
+
+        self.f_a_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.f_a_proj",
+        )
+
+        self.f_b_proj = ColumnParallelLinear(
+            self.head_dim,
+            projection_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.f_b_proj",
+        )
+        self.dt_bias = nn.Parameter(
+            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
+        )
+
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.b_proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.num_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.b_proj",
+        )
+
+        self.q_conv1d = ColumnParallelLinear(
+            input_size=self.conv_size,
+            output_size=projection_size,
+            bias=False,
+            params_dtype=torch.float32,
+            prefix=f"{prefix}.q_conv1d",
+        )
+        self.k_conv1d = ColumnParallelLinear(
+            input_size=self.conv_size,
+            output_size=projection_size,
+            bias=False,
+            params_dtype=torch.float32,
+            prefix=f"{prefix}.k_conv1d",
+        )
+        self.v_conv1d = ColumnParallelLinear(
+            input_size=self.conv_size,
+            output_size=projection_size,
+            bias=False,
+            params_dtype=torch.float32,
+            prefix=f"{prefix}.v_conv1d",
+        )
+        # unsqueeze to fit conv1d weights shape into the linear weights shape.
+        # Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear` and `set_weight_attrs`
+        # doesn't allow to override it
+        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
+        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
+        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)
+
+        self.A_log = nn.Parameter(
+            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
+        )
+        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})
+
+        self.g_a_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.g_a_proj",
+        )
+        self.g_b_proj = ColumnParallelLinear(
+            self.head_dim,
+            projection_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.g_b_proj",
+        )
+        self.o_norm = FusedRMSNormGated(
+            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
+        )
+        self.o_proj = RowParallelLinear(
+            projection_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ) -> None:
+        q_proj_states = self.q_proj(hidden_states)[0]
+        k_proj_states = self.k_proj(hidden_states)[0]
+        v_proj_states = self.v_proj(hidden_states)[0]
+
+        q_conv_weights = self.q_conv1d.weight.view(
+            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
+        )
+        k_conv_weights = self.k_conv1d.weight.view(
+            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
+        )
+        v_conv_weights = self.v_conv1d.weight.view(
+            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
+        )
+
+        kwargs = {
+            "q_proj_states": q_proj_states,
+            "k_proj_states": k_proj_states,
+            "v_proj_states": v_proj_states,
+            "q_conv_weights": q_conv_weights,
+            "k_conv_weights": k_conv_weights,
+            "v_conv_weights": v_conv_weights,
+            "q_conv_bias": self.q_conv1d.bias,
+            "k_conv_bias": self.k_conv1d.bias,
+            "v_conv_bias": self.v_conv1d.bias,
+            "dt_bias": self.dt_bias,
+            "b_proj": self.b_proj,
+            "f_a_proj": self.f_a_proj,
+            "f_b_proj": self.f_b_proj,
+            "A_log": self.A_log,
+            "head_dim": self.head_dim,
+            "hidden_states": hidden_states,
+            "layer_id": self.layer_idx,
+        }
+
+        core_attn_out = forward_batch.attn_backend.forward(
+            q=None,
+            k=None,
+            v=None,
+            layer=None,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+        core_attn_out = self.o_norm(core_attn_out, g)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]
+
+
+class KimiDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: KimiLinearConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.is_moe = config.is_moe
+
+        if config.is_kda_layer(layer_idx):
+            self.self_attn = KimiDeltaAttention(
+                layer_idx=layer_idx,
+                hidden_size=config.hidden_size,
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
+        else:
+            self.self_attn = KimiMLAAttention(
+                layer_id=layer_idx,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+                config=config,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=config.q_lora_rank,
+                kv_lora_rank=config.kv_lora_rank,
+                skip_rope=True,
+            )
+
+        if (
+            self.is_moe
+            and config.num_experts is not None
+            and layer_idx >= config.first_k_dense_replace
+            and layer_idx % config.moe_layer_freq == 0
+        ):
+            self.block_sparse_moe = KimiMoE(
+                config=config,
+                quant_config=quant_config,
+                layer_idx=layer_idx,
+                prefix=f"{prefix}.mlp",
+            )
+            self.mlp = self.block_sparse_moe
+        else:
+            self.mlp = KimiMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            positions=positions,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class KimiLinearModel(nn.Module):
+    def __init__(
+        self,
+        config: KimiLinearConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=f"{prefix}.embed_tokens",
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: KimiDecoderLayer(
+                layer_idx=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=f"{prefix}.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+        world_size = get_tensor_model_parallel_world_size()
+        assert (
+            config.num_attention_heads % world_size == 0
+        ), "num_attention_heads must be divisible by world_size"
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: torch.Tensor | None = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        total_num_layers = self.end_layer - self.start_layer
+        device = hidden_states.device
+        zero_allocator = BumpAllocator(
+            buffer_size=total_num_layers * 2,
+            dtype=torch.float32,
+            device=device,
+        )
+        # TODO: capture aux hidden states
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            ctx = get_global_expert_distribution_recorder().with_current_layer(i)
+            with ctx:
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    forward_batch=forward_batch,
+                    residual=residual,
+                    zero_allocator=zero_allocator,
+                )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class KimiLinearForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: KimiLinearConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = KimiLinearModel(
+            config, quant_config, prefix=maybe_prefix(prefix, "model")
+        )
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_last_rank:
+            self.lm_head = ParallelLMHead(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(config=config, logit_scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            inputs_embeds,
+            pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        if self.config.is_moe:
+            # Params for weights, fp8 weight scales, fp8 activation scales
+            # (param_name, weight_name, expert_id, shard_id)
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                ckpt_gate_proj_name="w1",
+                ckpt_down_proj_name="w2",
+                ckpt_up_proj_name="w3",
+                num_experts=self.config.num_experts,
+            )
+        else:
+            expert_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for args in weights:
+            name, loaded_weight = args[:2]
+            kwargs = args[2] if len(args) > 2 else {}
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
+                    expert_params_mapping
+                ):
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # if is_pp_missing_parameter(name, self):
+                    #     continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        expert_id=expert_id,
+                        shard_id=shard_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if (
+                        name.endswith(".bias")
+                        and name not in params_dict
+                        and not self.config.is_linear_attn
+                    ):  # noqa: E501
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    # if is_pp_missing_parameter(name, self):
+                    #     continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, **kwargs)
+            loaded_params.add(name)
+
+        for layer_id in self.config.full_attention_layer_ids:
+            self_attn = self.model.layers[layer_id].self_attn
+            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+
+
+EntryClass = KimiLinearForCausalLM
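
The new file ends by exporting `EntryClass = KimiLinearForCausalLM`, which is how sglang's model loader discovers the architecture. As rough orientation only (not part of the diff), a Kimi-Linear checkpoint whose config resolves to `KimiLinearConfig` could be exercised through sglang's offline Engine API along these lines; the checkpoint path below is a placeholder, not a name taken from this release:

# Hedged sketch, not from the diff: drives the newly registered
# KimiLinearForCausalLM through sglang's offline Engine API.
import sglang as sgl

if __name__ == "__main__":
    # Placeholder path; point at an actual Kimi-Linear checkpoint directory.
    llm = sgl.Engine(model_path="/path/to/kimi-linear-checkpoint")
    sampling_params = {"temperature": 0.0, "max_new_tokens": 32}
    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0]["text"])
    llm.shutdown()
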
sglang/srt/models/llama4.py
@@ -148,7 +148,7 @@ class Llama4MoE(nn.Module):
         return out_aD
 
     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4 and _is_cuda:
+        if _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
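
For reference, the llama4.py hunk above widens a fast path: in 0.5.4.post1 the shared+routed overlap branch of `Llama4MoE._forward_core` ran only for CUDA batches with fewer than 4 tokens, while in 0.5.4.post2 it runs for every CUDA batch. A standalone paraphrase of that dispatch change (stand-in functions, not sglang code):

# Illustrative stand-ins mirroring the Llama4MoE._forward_core change above;
# they only report which branch each version would take.
def path_before(num_tokens: int, is_cuda: bool) -> str:
    # 0.5.4.post1: overlap path only for small CUDA batches
    return "overlap" if (num_tokens < 4 and is_cuda) else "normal"

def path_after(num_tokens: int, is_cuda: bool) -> str:
    # 0.5.4.post2: overlap path for every CUDA batch
    return "overlap" if is_cuda else "normal"

for n in (1, 8, 64):
    print(n, path_before(n, True), "->", path_after(n, True))
# prints: 1 overlap -> overlap / 8 normal -> overlap / 64 normal -> overlap
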