sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sglang/lang/chat_template.py +24 -0
  2. sglang/srt/configs/model_config.py +4 -0
  3. sglang/srt/conversation.py +29 -4
  4. sglang/srt/layers/attention/flashattention_backend.py +286 -9
  5. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  6. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  15. sglang/srt/layers/quantization/__init__.py +1 -0
  16. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  17. sglang/srt/layers/quantization/fp8.py +3 -1
  18. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  19. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  20. sglang/srt/layers/radix_attention.py +2 -0
  21. sglang/srt/layers/rotary_embedding.py +63 -0
  22. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  23. sglang/srt/model_executor/model_runner.py +1 -0
  24. sglang/srt/models/llama.py +12 -4
  25. sglang/srt/models/llama4.py +420 -0
  26. sglang/srt/models/mllama4.py +154 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
  29. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
  30. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  31. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0

sglang/srt/managers/multimodal_processors/mllama4.py
@@ -0,0 +1,161 @@
+ from typing import List, Mapping, Optional, Tuple, Union
+
+ import torch
+ from PIL import Image
+ from transformers import Llama4Processor
+ from transformers.image_utils import SizeDict
+ from transformers.models.llama4.image_processing_llama4 import (
+     find_supported_resolutions,
+     get_best_fit,
+ )
+
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
+ )
+ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+ from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
+ from sglang.srt.utils import load_image
+
+
+ class Mllama4ImageProcessor(BaseMultimodalProcessor):
+     models = [Llama4ForConditionalGeneration]
+
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.vision_config = hf_config.vision_config
+         self.text_config = hf_config.text_config
+         self.multimodal_tokens = MultimodalSpecialTokens(
+             image_token=_processor.image_token
+         )
+
+     async def process_mm_data_async(
+         self,
+         image_data: List[Union[str, bytes]],
+         input_text,
+         max_req_input_len=None,
+         *args,
+         **kwargs,
+     ):
+         if not image_data:
+             return None
+
+         if isinstance(input_text, list):
+             assert len(input_text) and isinstance(input_text[0], int)
+             input_text = self._processor.tokenizer.decode(input_text)
+
+         # Process images and text using the base processor's load_mm_data method
+         processed_data = self.load_mm_data(
+             prompt=input_text,
+             multimodal_tokens=self.multimodal_tokens,
+             max_req_input_len=max_req_input_len or 4096,
+             image_data=image_data,
+             return_text=True,
+         )
+
+         # Process the images using the processor
+         processor = Llama4Processor.from_pretrained(
+             self.server_args.model_path, **kwargs
+         )
+
+         # Process the prompt and images
+         image_inputs = processor(
+             text=processed_data.input_text,
+             images=processed_data.images,
+             return_tensors="pt",
+         )
+
+         # Handle image resolutions and aspect ratios
+         if "pixel_values" in image_inputs:
+             image_processor = processor.image_processor
+             tokenizer = self._processor.tokenizer
+
+             # Calculate tile size and find supported resolutions
+             tile_size = self.vision_config.image_size
+             max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+             possible_resolutions = find_supported_resolutions(
+                 max_num_chunks=max_num_tiles,
+                 patch_size=SizeDict(height=tile_size, width=tile_size),
+             )
+
+             # Find best fit for each image
+             best_fit_sizes = [
+                 get_best_fit(
+                     (image.size[1], image.size[0]),  # (height, width)
+                     torch.tensor(possible_resolutions),
+                     resize_to_max_canvas=image_processor.resize_to_max_canvas,
+                 )
+                 for image in processed_data.images
+             ]
+
+             # Calculate aspect ratios and patches per image
+             aspect_ratios = [
+                 (image_size[0] // tile_size, image_size[1] // tile_size)
+                 for image_size in best_fit_sizes
+             ]
+
+             patches_per_image = [
+                 1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+             ]
+
+             # Add to image_inputs
+             image_inputs["aspect_ratios"] = aspect_ratios
+             image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+
+             # Process embed_is_patch
+             vocab = tokenizer.get_vocab()
+             patch_id = vocab.get(processor.img_patch_token, -1)
+             image_end_id = vocab.get(processor.end_of_img_token, -1)
+
+             if patch_id != -1 and image_end_id != -1:
+                 input_ids = image_inputs["input_ids"].view(-1)
+
+                 # Remove BOS token if present
+                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                     input_ids = input_ids[1:]
+
+                 # Find image end indices and split input_ids
+                 image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+
+                 if image_end_indices.size(0) > 0:
+                     # Split at image boundaries
+                     split_indices = (image_end_indices + 1)[:-1]
+                     split_input_ids = torch.tensor_split(input_ids, split_indices)
+                     split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+
+                     # Create embed_is_patch for each image
+                     embed_is_patch = []
+                     for per_image_input_ids in split_input_ids:
+                         embed_is_patch.append(per_image_input_ids == patch_id)
+
+                     image_inputs["embed_is_patch"] = embed_is_patch
+
+         # Convert to the format expected by SGLang
+         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+         # Add metadata for image processing
+         image_inputs["mm_items"] = [
+             MultimodalDataItem(
+                 pixel_values=image_inputs["pixel_values"],
+                 modality=Modality.IMAGE,
+                 # Add additional metadata needed for Llama4 vision processing
+                 embed_is_patch=image_inputs.get("embed_is_patch", None),
+                 aspect_ratios=image_inputs.get("aspect_ratios", None),
+                 patches_per_image=image_inputs.get("patches_per_image", None),
+             )
+         ]
+
+         return image_inputs
+
+     def get_patch_per_chunk(self):
+         """Calculate patches per chunk based on vision config"""
+         image_size = self.vision_config.image_size
+         patch_size = self.vision_config.patch_size
+
+         assert (
+             image_size % patch_size == 0
+         ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
+
+         ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
+         return (image_size // patch_size) ** 2 // ds_ratio

sglang/srt/model_executor/model_runner.py
@@ -128,6 +128,7 @@ class ModelRunner:
              self.model_config.attention_arch == AttentionArch.MLA
              and not server_args.disable_mla
          )
+         self.attention_chunk_size = model_config.attention_chunk_size

          # Model-specific adjustment
          self.model_specific_adjustment()

sglang/srt/models/llama.py
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
          hidden_act: str,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
+         reduce_results: bool = True,
      ) -> None:
          super().__init__()
          self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
              bias=False,
              quant_config=quant_config,
              prefix=add_prefix("down_proj", prefix),
+             reduce_results=reduce_results,
          )
          if hidden_act != "silu":
              raise ValueError(
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
          self.layers = make_layers(
              config.num_hidden_layers,
              lambda idx, prefix: LlamaDecoderLayer(
-                 config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
              ),
              prefix="model.layers",
          )
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.model = LlamaModel(
-             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-         )
+         self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
          # Llama 3.2 1B Instruct set tie_word_embeddings to True
          # Llama 3.1 8B Instruct set tie_word_embeddings to False
          if self.config.tie_word_embeddings:
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):

          self.capture_aux_hidden_states = False

+     def _init_model(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         return LlamaModel(config, quant_config=quant_config, prefix=prefix)
+
      @torch.no_grad()
      def forward(
          self,

sglang/srt/models/llama4.py
@@ -0,0 +1,420 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
+
+ import logging
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Llama4TextConfig
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
+ from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+
+ logger = logging.getLogger(__name__)
+
+
+ class Llama4MoE(nn.Module):
+
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
+     @staticmethod
+     def custom_routing_function(
+         hidden_states: torch.Tensor,
+         gating_output: torch.Tensor,
+         topk: int,
+         renormalize: bool,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
+             hidden_states.dtype
+         )
+         return (
+             router_scores_aK.view(-1).reshape(router_scores_aK.shape),
+             router_indices_aK.to(torch.int32),
+         )
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.top_k = config.num_experts_per_tok
+
+         intermediate_size_moe = config.intermediate_size
+         self.router = ReplicatedLinear(
+             config.hidden_size,
+             config.num_local_experts,
+             bias=False,
+             quant_config=None,
+             prefix=add_prefix("router", prefix),
+         )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_local_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             custom_routing_function=Llama4MoE.custom_routing_function,
+             intermediate_size=intermediate_size_moe,
+             reduce_results=False,
+             renormalize=False,
+             quant_config=quant_config,
+             apply_router_weight_on_input=True,
+             prefix=add_prefix("experts", prefix),
+         )
+
+         self.shared_expert = LlamaMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=intermediate_size_moe,
+             hidden_act="silu",
+             quant_config=quant_config,
+             prefix=add_prefix("shared_expert", prefix),
+             reduce_results=False,  # We need to do scatter before reduce
+         )
+
+     def forward(self, hidden_states):
+         # router_scores: [num_tokens, num_experts]
+         router_logits, _ = self.router(hidden_states)
+         shared_out = self.shared_expert(hidden_states)
+         routed_out = self.experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+         )
+         out_aD = routed_out + shared_out
+
+         if self.tp_size > 1:
+             out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+         return out_aD
+
+
+ class Llama4Attention(nn.Module):
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         bias_o_proj: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = hidden_size
+         self.use_rope = int((layer_id + 1) % 4 != 0)
+         self.use_qk_norm = config.use_qk_norm and self.use_rope
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = config.head_dim
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.attn_temperature_tuning = config.attn_temperature_tuning
+         self.floor_scale = config.floor_scale
+         self.attn_scale = config.attn_scale
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.n_rep = self.num_heads // self.num_kv_heads
+         self.qk_norm = (
+             RMSNorm(
+                 hidden_size=self.head_dim,
+                 eps=config.rms_norm_eps,
+             )
+             if self.use_qk_norm
+             else None
+         )
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size=hidden_size,
+             head_size=self.head_dim,
+             total_num_heads=self.total_num_heads,
+             total_num_kv_heads=self.total_num_kv_heads,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             input_size=self.total_num_heads * self.head_dim,
+             output_size=hidden_size,
+             bias=bias_o_proj,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+         is_neox_style = True
+         is_gguf = quant_config and quant_config.get_name() == "gguf"
+         if is_gguf and config.model_type in ["llama", "llama4"]:
+             is_neox_style = False
+
+         self.rotary_emb = (
+             get_rope(
+                 self.head_dim,
+                 rotary_dim=self.head_dim,
+                 max_position=max_position_embeddings,
+                 base=int(rope_theta),
+                 rope_scaling=rope_scaling if rope_scaling != "default" else None,
+                 is_neox_style=is_neox_style,
+             )
+             if self.use_rope
+             else None
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+             use_irope=self.use_rope,
+         )
+
+     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+         floor = torch.floor((positions + 1.0) / self.floor_scale)
+         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+
+         return attn_scale.unsqueeze(-1)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+         if self.rotary_emb is not None:
+             q, k = self.rotary_emb(positions, q, k)
+
+         if self.qk_norm is not None:
+             # TODO: support float
+             q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
+             k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
+             q = self.qk_norm(q).to(q.dtype)
+             k = self.qk_norm(k).to(k.dtype)
+             q = q.reshape(-1, self.q_size)
+             k = k.reshape(-1, self.kv_size)
+
+         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
+         # the inference-time temperature tuning function is customized to not affect short context
+         # while working at very long context
+         # https://arxiv.org/abs/2501.19399
+         if self.attn_temperature_tuning and not self.use_rope:
+             attn_scale = self._get_attn_scale(positions)
+             q = (q * attn_scale).to(q.dtype)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Llama4DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = config.hidden_size
+         rope_theta = config.rope_theta
+         rope_scaling = config.rope_scaling
+         max_position_embeddings = config.max_position_embeddings
+
+         self.self_attn = Llama4Attention(
+             config=config,
+             layer_id=layer_id,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             bias=False,
+             bias_o_proj=False,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+         if is_moe_layer:
+             self.feed_forward = Llama4MoE(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         else:
+             self.feed_forward = LlamaMLP(
+                 hidden_size=self.hidden_size,
+                 intermediate_size=config.intermediate_size_mlp,
+                 hidden_act="silu",
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.feed_forward(hidden_states)
+         return hidden_states, residual
+
+
+ class Llama4Model(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("embed_tokens", prefix),
+         )
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Llama4DecoderLayer(
+                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+             ),
+             prefix="model.layers",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.layers_to_capture = []
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         aux_hidden_states = []
+         for i in range(len(self.layers)):
+             if i in self.layers_to_capture:
+                 aux_hidden_states.append(hidden_states + residual)
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+
+         if len(aux_hidden_states) == 0:
+             return hidden_states
+
+         return hidden_states, aux_hidden_states
+
+
+ class Llama4ForCausalLM(LlamaForCausalLM):
+
+     packed_modules_mapping = {
+         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+         "gate_up_proj": ["gate_proj", "up_proj"],
+     }
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__(config, quant_config, prefix)
+
+     def _init_model(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         return Llama4Model(config, quant_config=quant_config, prefix=prefix)
+
+
+ EntryClass = [Llama4ForCausalLM]