sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,487 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
+
+ import logging
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Llama4TextConfig
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.dp_attention import (
+     dp_gather_partial,
+     dp_scatter,
+     get_attention_dp_size,
+     get_attention_tp_rank,
+     get_attention_tp_size,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
+ from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+
+ logger = logging.getLogger(__name__)
+
+
+ class Llama4MoE(nn.Module):
+
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
+     @staticmethod
+     def custom_routing_function(
+         hidden_states: torch.Tensor,
+         gating_output: torch.Tensor,
+         topk: int,
+         renormalize: bool,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
+         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
+             hidden_states.dtype
+         )
+         return (
+             router_scores_aK.view(-1).reshape(router_scores_aK.shape),
+             router_indices_aK.to(torch.int32),
+         )
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.top_k = config.num_experts_per_tok
+
+         intermediate_size_moe = config.intermediate_size
+         self.router = ReplicatedLinear(
+             config.hidden_size,
+             config.num_local_experts,
+             bias=False,
+             quant_config=None,
+             prefix=add_prefix("router", prefix),
+         )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_local_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             custom_routing_function=Llama4MoE.custom_routing_function,
+             intermediate_size=intermediate_size_moe,
+             reduce_results=False,
+             renormalize=False,
+             quant_config=quant_config,
+             apply_router_weight_on_input=True,
+             prefix=add_prefix("experts", prefix),
+         )
+
+         self.shared_expert = LlamaMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=intermediate_size_moe,
+             hidden_act="silu",
+             quant_config=quant_config,
+             prefix=add_prefix("shared_expert", prefix),
+             reduce_results=False,  # We need to do scatter before reduce
+         )
+
+     def forward(self, hidden_states):
+         # router_scores: [num_tokens, num_experts]
+         router_logits, _ = self.router(hidden_states)
+         shared_out = self.shared_expert(hidden_states)
+         routed_out = self.experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+         )
+         out_aD = routed_out + shared_out
+
+         if self.tp_size > 1:
+             out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+         return out_aD
+
+
+ class Llama4Attention(nn.Module):
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         bias_o_proj: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = hidden_size
+         self.use_rope = int((layer_id + 1) % 4 != 0)
+         self.use_qk_norm = config.use_qk_norm and self.use_rope
+
+         self.dp_size = get_attention_dp_size()
+         attn_tp_rank = get_attention_tp_rank()
+         attn_tp_size = get_attention_tp_size()
+
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % attn_tp_size == 0
+         self.num_heads = self.total_num_heads // attn_tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+         self.head_dim = config.head_dim
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.attn_temperature_tuning = config.attn_temperature_tuning
+         self.floor_scale = config.floor_scale
+         self.attn_scale = config.attn_scale
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.n_rep = self.num_heads // self.num_kv_heads
+         self.qk_norm = (
+             RMSNorm(
+                 hidden_size=self.head_dim,
+                 eps=config.rms_norm_eps,
+             )
+             if self.use_qk_norm
+             else None
+         )
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size=hidden_size,
+             head_size=self.head_dim,
+             total_num_heads=self.total_num_heads,
+             total_num_kv_heads=self.total_num_kv_heads,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+             tp_rank=attn_tp_rank,
+             tp_size=attn_tp_size,
+         )
+
+         self.o_proj = RowParallelLinear(
+             input_size=self.total_num_heads * self.head_dim,
+             output_size=hidden_size,
+             bias=bias_o_proj,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+             tp_rank=attn_tp_rank,
+             tp_size=attn_tp_size,
+             reduce_results=False,
+         )
+         is_neox_style = True
+         is_gguf = quant_config and quant_config.get_name() == "gguf"
+         if is_gguf and config.model_type in ["llama", "llama4"]:
+             is_neox_style = False
+
+         self.rotary_emb = (
+             get_rope(
+                 self.head_dim,
+                 rotary_dim=self.head_dim,
+                 max_position=max_position_embeddings,
+                 base=int(rope_theta),
+                 rope_scaling=rope_scaling if rope_scaling != "default" else None,
+                 is_neox_style=is_neox_style,
+             )
+             if self.use_rope
+             else None
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+             use_irope=self.use_rope,
+         )
+
+     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+         floor = torch.floor((positions + 1.0) / self.floor_scale)
+         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+         return attn_scale.unsqueeze(-1)
+
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
+     def _mul_attn_scale(self, positions, q):
+         attn_scale = self._get_attn_scale(positions)
+         return (q * attn_scale).to(q.dtype)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+
+         qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
+
+         if self.rotary_emb is not None:
+             q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+             q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+             assert (q_out_unused is q_view) and (k_out_unused is k_view)
+             del q_view, k_view, q_out_unused, k_out_unused
+
+         if self.qk_norm is not None:
+             # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+             qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+             qk = self.qk_norm(qk).to(torch.bfloat16)
+             qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+         q, k = qk.split([self.q_size, self.kv_size], dim=-1)
+
+         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
+         # the inference-time temperature tuning function is customized to not affect short context
+         # while working at very long context
+         # https://arxiv.org/abs/2501.19399
+         if self.attn_temperature_tuning and not self.use_rope:
+             q = self._mul_attn_scale(positions=positions, q=q)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Llama4DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = config.hidden_size
+         rope_theta = config.rope_theta
+         rope_scaling = config.rope_scaling
+         max_position_embeddings = config.max_position_embeddings
+         self.dp_size = get_attention_dp_size()
+         self.attn_tp_size = get_attention_tp_size()
+         self.attn_tp_rank = get_attention_tp_rank()
+
+         self.self_attn = Llama4Attention(
+             config=config,
+             layer_id=layer_id,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             bias=False,
+             bias_o_proj=False,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+         if is_moe_layer:
+             self.feed_forward = Llama4MoE(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         else:
+             self.feed_forward = LlamaMLP(
+                 hidden_size=self.hidden_size,
+                 intermediate_size=config.intermediate_size_mlp,
+                 hidden_act="silu",
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if hidden_states.shape[0] == 0:
+             residual = hidden_states
+         else:
+             # Self Attention
+             if residual is None:
+                 residual = hidden_states
+                 hidden_states = self.input_layernorm(hidden_states)
+             else:
+                 hidden_states, residual = self.input_layernorm(hidden_states, residual)
+             hidden_states = self.self_attn(
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 forward_batch=forward_batch,
+             )
+
+         # Gather
+         if get_tensor_model_parallel_world_size() > 1:
+             # all gather and all reduce
+             if self.dp_size != 1:
+                 if self.attn_tp_rank == 0:
+                     hidden_states += residual
+                 hidden_states, local_hidden_states = (
+                     forward_batch.gathered_buffer,
+                     hidden_states,
+                 )
+                 dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                 dp_scatter(residual, hidden_states, forward_batch)
+                 hidden_states = self.post_attention_layernorm(hidden_states)
+             else:
+                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                 hidden_states, residual = self.post_attention_layernorm(
+                     hidden_states, residual
+                 )
+         else:
+             hidden_states, residual = self.post_attention_layernorm(
+                 hidden_states, residual
+             )
+
+         # Fully Connected
+         hidden_states = self.feed_forward(hidden_states)
+
+         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+         # Scatter
+         if self.dp_size != 1:
+             # important: forward batch.gathered_buffer is used both after scatter and after gather.
+             # be careful about this!
+             hidden_states, global_hidden_states = (
+                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                 hidden_states,
+             )
+             dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+         return hidden_states, residual
+
+
+ class Llama4Model(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("embed_tokens", prefix),
+             enable_tp=not global_server_args_dict["enable_dp_attention"],
+         )
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Llama4DecoderLayer(
+                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+             ),
+             prefix=add_prefix("layers", prefix),
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.layers_to_capture = []
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         aux_hidden_states = []
+         for i in range(len(self.layers)):
+             if i in self.layers_to_capture:
+                 aux_hidden_states.append(hidden_states + residual)
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+         if not forward_batch.forward_mode.is_idle():
+             hidden_states, _ = self.norm(hidden_states, residual)
+
+         if len(aux_hidden_states) == 0:
+             return hidden_states
+
+         return hidden_states, aux_hidden_states
+
+
+ class Llama4ForCausalLM(LlamaForCausalLM):
+     packed_modules_mapping = {
+         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+         "gate_up_proj": ["gate_proj", "up_proj"],
+     }
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__(config, quant_config, prefix)
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def _init_model(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         return Llama4Model(config, quant_config=quant_config, prefix=prefix)
+
+
+ EntryClass = [Llama4ForCausalLM]
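
Note on the two formula-heavy pieces of llama4.py above: the MoE router keeps only the top-k router logits per token and squashes those selected logits with a sigmoid (renormalize=False, with the weights applied on the expert input), and NoPE layers scale the query by a logarithmic, position-dependent factor so short contexts are nearly unaffected while very long contexts get a higher attention temperature. A minimal standalone sketch of just that math in plain PyTorch (independent of SGLang's FusedMoE/RadixAttention plumbing; the helper names and the floor_scale/attn_scale values below are illustrative assumptions, not taken from a real checkpoint config):

import torch

# Sketch of Llama4MoE.custom_routing_function: pick the top-k logits per token,
# then map the selected logits through a sigmoid to get expert weights.
def routing_sketch(gating_output: torch.Tensor, topk: int):
    scores, indices = torch.topk(gating_output, topk, dim=-1)
    scores = torch.sigmoid(scores.float()).to(gating_output.dtype)
    return scores, indices.to(torch.int32)

# Sketch of Llama4Attention._get_attn_scale: a step-wise log scale applied to
# the query on NoPE layers when attn_temperature_tuning is enabled.
def attn_scale_sketch(positions: torch.Tensor, floor_scale: float, attn_scale: float):
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

if __name__ == "__main__":
    router_logits = torch.randn(4, 16)  # [num_tokens, num_experts]
    print(routing_sketch(router_logits, topk=1))
    positions = torch.arange(0, 1_000_000, 200_000)
    print(attn_scale_sketch(positions, floor_scale=8192.0, attn_scale=0.1))

The scale grows roughly with log(position), so tokens early in the context see a factor close to 1.0 and only very deep positions get a noticeably larger query scale.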
@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              layer_id=layer_id,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_local_heads,
              layer_id=layer_id,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
              num_kv_heads=1,
              layer_id=layer_id,
              v_head_dim=self.kv_lora_rank,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              layer_id=layer_id,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              layer_id=layer_id,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
      QKVParallelLinear,
+     ReplicatedLinear,
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
      def __init__(
          self,
          config: config_mllama.MllamaVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
          is_gated: bool = False,
          prefix: str = "",
      ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
              self.num_attention_heads,
              self.hidden_size,
              use_qkv_parallel=True,
-             quant_config=None,
+             quant_config=quant_config,
              dropout=0.0,
              use_context_forward=False,
              softmax_in_single_precision=False,
              flatten_batch=False,
              prefix=add_prefix("self_attn", prefix),
          )
-         self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+         self.mlp = MllamaVisionMLP(
+             config, quant_config, prefix=add_prefix("mlp", prefix)
+         )

          self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
          self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
      def __init__(
          self,
          config: config_mllama.MllamaVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
          num_layers=32,
          is_gated=False,
          output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
          self.layers = nn.ModuleList(
              [
                  MllamaVisionEncoderLayer(
-                     config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                     config,
+                     quant_config,
+                     is_gated,
+                     prefix=add_prefix(f"layers.{i}", prefix),
                  )
                  for i in range(num_layers)
              ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):


  class MllamaVisionModel(nn.Module):
-     def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+     def __init__(
+         self,
+         config: config_mllama.MllamaVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
          super().__init__()
          self.image_size = config.image_size
          self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
          # encoders
          self.transformer = MllamaVisionEncoder(
              config,
+             quant_config,
              config.num_hidden_layers,
              is_gated=False,
              output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
          )
          self.global_transformer = MllamaVisionEncoder(
              config,
+             quant_config,
              config.num_global_layers,
              is_gated=True,
              prefix=add_prefix("global_transformer", prefix),
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
              self.num_local_key_value_heads,
              layer_id=layer_id,
              is_cross_attention=True,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):


  class MllamaForConditionalGeneration(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     # in TP, these weights are partitioned along the column dimension (dim=-1)
+     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
      def __init__(
          self,
          config: config_mllama.MllamaConfig,
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
          prefix: str = "",
      ):
          super().__init__()
+         self.quant_config = quant_config
          self.vocab_size = config.text_config.vocab_size
          self.hidden_size = config.text_config.hidden_size
          self.max_num_tiles = config.vision_config.max_num_tiles
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
          self.image_size = config.vision_config.image_size

          self.vision_model = MllamaVisionModel(
-             config.vision_config, prefix=add_prefix("vision_model", prefix)
+             config.vision_config,
+             quant_config=quant_config,
+             prefix=add_prefix("vision_model", prefix),
          )
          self.language_model = MllamaForCausalLM(
              config.text_config,
              quant_config=quant_config,
              prefix=add_prefix("language_model", prefix),
          )
-         self.multi_modal_projector = nn.Linear(
+         self.multi_modal_projector = ReplicatedLinear(
              config.vision_config.vision_output_dim,
              config.text_config.hidden_size,
              bias=True,
+             quant_config=quant_config,
+             prefix="multi_modal_projector",
          )
          self.logits_processor = LogitsProcessor(config.text_config)
          self.capture_mode = False
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
              cross_attention_states = self.vision_model(
                  batched_images, batched_ar_ids, batched_ar_mask
              )
-             cross_attention_states = self.multi_modal_projector(cross_attention_states)
+             cross_attention_states, _ = self.multi_modal_projector(
+                 cross_attention_states
+             )

              bs, _, _, _, image_token_dim = cross_attention_states.shape
              cross_attention_states = cross_attention_states.view(
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
              if "vision_model" in name:
                  # adapt to VisionAttention
                  name = name.replace("self_attn.o_proj", "self_attn.proj")
-
              param = params_dict.pop(name)
              weight_loader = getattr(param, "weight_loader", default_weight_loader)
              weight_loader(param, loaded_weight)
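
The mllama.py hunks above all serve one change: the optional quant_config is now threaded from MllamaForConditionalGeneration down through every vision submodule (encoder layers, MLPs, the projector) instead of being dropped or hard-coded to None, and the projector becomes a ReplicatedLinear whose forward returns an (output, output_bias) tuple, which is why the call site gains the `, _ =` unpacking. A rough sketch of that threading pattern with stand-in classes (plain PyTorch, not SGLang's actual APIs):

from typing import Optional

import torch
from torch import nn


class QuantConfigStub:
    """Stand-in for SGLang's QuantizationConfig (illustration only)."""


class ProjectorStub(nn.Module):
    """Stand-in for a parallel linear layer whose forward returns (output, bias)."""

    def __init__(self, in_dim: int, out_dim: int,
                 quant_config: Optional[QuantConfigStub] = None):
        super().__init__()
        self.quant_config = quant_config  # a real layer would pick a quantized kernel here
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor):
        return self.linear(x), self.linear.bias


class VisionModelStub(nn.Module):
    def __init__(self, dim: int, quant_config: Optional[QuantConfigStub] = None):
        super().__init__()
        # The point of the hunks: quant_config is forwarded to submodules
        # instead of being dropped, so vision weights can be quantized too.
        self.proj = ProjectorStub(dim, dim, quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.proj(x)  # tuple return forces the `, _ =` unpacking seen above
        return out


if __name__ == "__main__":
    model = VisionModelStub(16, quant_config=QuantConfigStub())
    print(model(torch.randn(2, 16)).shape)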