sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/layers/attention/flashattention_backend.py +286 -9
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -0
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/model_executor/model_runner.py +1 -0
- sglang/srt/models/llama.py +12 -4
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/mllama4.py +154 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
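
For readers who want to reproduce or explore this comparison locally, the two wheels can be downloaded (for example with `pip download sglang==0.4.4.post4 --no-deps` and `pip download sglang==0.4.5 --no-deps`) and diffed directly. The sketch below uses only the Python standard library and assumes both wheel files sit in the current directory; the member path is one of the files listed above.

# Minimal sketch: diff one module between the two published wheels.
# Assumes both .whl files have already been downloaded to the working directory.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.4.post4-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.5-py3-none-any.whl"
MEMBER = "sglang/srt/models/llama.py"  # any path from the file list above

def read_lines(wheel_path: str, member: str) -> list[str]:
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)

diff = difflib.unified_diff(
    read_lines(OLD_WHEEL, MEMBER),
    read_lines(NEW_WHEEL, MEMBER),
    fromfile=f"{OLD_WHEEL}/{MEMBER}",
    tofile=f"{NEW_WHEEL}/{MEMBER}",
)
print("".join(diff))
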
sglang/srt/managers/multimodal_processors/mllama4.py
ADDED
@@ -0,0 +1,161 @@
+from typing import List, Mapping, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import Llama4Processor
+from transformers.image_utils import SizeDict
+from transformers.models.llama4.image_processing_llama4 import (
+    find_supported_resolutions,
+    get_best_fit,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
+from sglang.srt.utils import load_image
+
+
+class Mllama4ImageProcessor(BaseMultimodalProcessor):
+    models = [Llama4ForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        max_req_input_len=None,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # Process images and text using the base processor's load_mm_data method
+        processed_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=max_req_input_len or 4096,
+            image_data=image_data,
+            return_text=True,
+        )
+
+        # Process the images using the processor
+        processor = Llama4Processor.from_pretrained(
+            self.server_args.model_path, **kwargs
+        )
+
+        # Process the prompt and images
+        image_inputs = processor(
+            text=processed_data.input_text,
+            images=processed_data.images,
+            return_tensors="pt",
+        )
+
+        # Handle image resolutions and aspect ratios
+        if "pixel_values" in image_inputs:
+            image_processor = processor.image_processor
+            tokenizer = self._processor.tokenizer
+
+            # Calculate tile size and find supported resolutions
+            tile_size = self.vision_config.image_size
+            max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+            possible_resolutions = find_supported_resolutions(
+                max_num_chunks=max_num_tiles,
+                patch_size=SizeDict(height=tile_size, width=tile_size),
+            )
+
+            # Find best fit for each image
+            best_fit_sizes = [
+                get_best_fit(
+                    (image.size[1], image.size[0]),  # (height, width)
+                    torch.tensor(possible_resolutions),
+                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
+                )
+                for image in processed_data.images
+            ]
+
+            # Calculate aspect ratios and patches per image
+            aspect_ratios = [
+                (image_size[0] // tile_size, image_size[1] // tile_size)
+                for image_size in best_fit_sizes
+            ]
+
+            patches_per_image = [
+                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+            ]
+
+            # Add to image_inputs
+            image_inputs["aspect_ratios"] = aspect_ratios
+            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+
+            # Process embed_is_patch
+            vocab = tokenizer.get_vocab()
+            patch_id = vocab.get(processor.img_patch_token, -1)
+            image_end_id = vocab.get(processor.end_of_img_token, -1)
+
+            if patch_id != -1 and image_end_id != -1:
+                input_ids = image_inputs["input_ids"].view(-1)
+
+                # Remove BOS token if present
+                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                    input_ids = input_ids[1:]
+
+                # Find image end indices and split input_ids
+                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+
+                if image_end_indices.size(0) > 0:
+                    # Split at image boundaries
+                    split_indices = (image_end_indices + 1)[:-1]
+                    split_input_ids = torch.tensor_split(input_ids, split_indices)
+                    split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+
+                    # Create embed_is_patch for each image
+                    embed_is_patch = []
+                    for per_image_input_ids in split_input_ids:
+                        embed_is_patch.append(per_image_input_ids == patch_id)
+
+                    image_inputs["embed_is_patch"] = embed_is_patch
+
+        # Convert to the format expected by SGLang
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        # Add metadata for image processing
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"],
+                modality=Modality.IMAGE,
+                # Add additional metadata needed for Llama4 vision processing
+                embed_is_patch=image_inputs.get("embed_is_patch", None),
+                aspect_ratios=image_inputs.get("aspect_ratios", None),
+                patches_per_image=image_inputs.get("patches_per_image", None),
+            )
+        ]
+
+        return image_inputs
+
+    def get_patch_per_chunk(self):
+        """Calculate patches per chunk based on vision config"""
+        image_size = self.vision_config.image_size
+        patch_size = self.vision_config.patch_size
+
+        assert (
+            image_size % patch_size == 0
+        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
+
+        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
+        return (image_size // patch_size) ** 2 // ds_ratio
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -128,6 +128,7 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not server_args.disable_mla
         )
+        self.attention_chunk_size = model_config.attention_chunk_size

         # Model-specific adjustment
         self.model_specific_adjustment()
sglang/srt/models/llama.py
CHANGED
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
         self.layers = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config,
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
             prefix="model.layers",
         )
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = LlamaModel(
-            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-        )
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
         if self.config.tie_word_embeddings:
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):

         self.capture_aux_hidden_states = False

+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return LlamaModel(config, quant_config=quant_config, prefix=prefix)
+
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/llama4.py
ADDED
@@ -0,0 +1,420 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Llama4TextConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
+from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class Llama4MoE(nn.Module):
+
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    @staticmethod
+    def custom_routing_function(
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
+            hidden_states.dtype
+        )
+        return (
+            router_scores_aK.view(-1).reshape(router_scores_aK.shape),
+            router_indices_aK.to(torch.int32),
+        )
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.top_k = config.num_experts_per_tok
+
+        intermediate_size_moe = config.intermediate_size
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("router", prefix),
+        )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+            intermediate_size=intermediate_size_moe,
+            reduce_results=False,
+            renormalize=False,
+            quant_config=quant_config,
+            apply_router_weight_on_input=True,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.shared_expert = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=intermediate_size_moe,
+            hidden_act="silu",
+            quant_config=quant_config,
+            prefix=add_prefix("shared_expert", prefix),
+            reduce_results=False,  # We need to do scatter before reduce
+        )
+
+    def forward(self, hidden_states):
+        # router_scores: [num_tokens, num_experts]
+        router_logits, _ = self.router(hidden_states)
+        shared_out = self.shared_expert(hidden_states)
+        routed_out = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+
+class Llama4Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        layer_id: int,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.use_rope = int((layer_id + 1) % 4 != 0)
+        self.use_qk_norm = config.use_qk_norm and self.use_rope
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.attn_temperature_tuning = config.attn_temperature_tuning
+        self.floor_scale = config.floor_scale
+        self.attn_scale = config.attn_scale
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.n_rep = self.num_heads // self.num_kv_heads
+        self.qk_norm = (
+            RMSNorm(
+                hidden_size=self.head_dim,
+                eps=config.rms_norm_eps,
+            )
+            if self.use_qk_norm
+            else None
+        )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias_o_proj,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type in ["llama", "llama4"]:
+            is_neox_style = False
+
+        self.rotary_emb = (
+            get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=int(rope_theta),
+                rope_scaling=rope_scaling if rope_scaling != "default" else None,
+                is_neox_style=is_neox_style,
+            )
+            if self.use_rope
+            else None
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+            use_irope=self.use_rope,
+        )
+
+    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+        floor = torch.floor((positions + 1.0) / self.floor_scale)
+        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+
+        return attn_scale.unsqueeze(-1)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb is not None:
+            q, k = self.rotary_emb(positions, q, k)
+
+        if self.qk_norm is not None:
+            # TODO: support float
+            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
+            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
+            q = self.qk_norm(q).to(q.dtype)
+            k = self.qk_norm(k).to(k.dtype)
+            q = q.reshape(-1, self.q_size)
+            k = k.reshape(-1, self.kv_size)
+
+        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
+        # the inference-time temperature tuning function is customized to not affect short context
+        # while working at very long context
+        # https://arxiv.org/abs/2501.19399
+        if self.attn_temperature_tuning and not self.use_rope:
+            attn_scale = self._get_attn_scale(positions)
+            q = (q * attn_scale).to(q.dtype)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Llama4DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = config.hidden_size
+        rope_theta = config.rope_theta
+        rope_scaling = config.rope_scaling
+        max_position_embeddings = config.max_position_embeddings
+
+        self.self_attn = Llama4Attention(
+            config=config,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=False,
+            bias_o_proj=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        if is_moe_layer:
+            self.feed_forward = Llama4MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("feed_forward", prefix),
+            )
+        else:
+            self.feed_forward = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=config.intermediate_size_mlp,
+                hidden_act="silu",
+                quant_config=quant_config,
+                prefix=add_prefix("feed_forward", prefix),
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class Llama4Model(nn.Module):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Llama4DecoderLayer(
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            ),
+            prefix="model.layers",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        aux_hidden_states = []
+        for i in range(len(self.layers)):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class Llama4ForCausalLM(LlamaForCausalLM):
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config, quant_config, prefix)
+
+    def _init_model(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return Llama4Model(config, quant_config=quant_config, prefix=prefix)
+
+
+EntryClass = [Llama4ForCausalLM]