sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,487 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
# Adapted from
|
16
|
+
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
|
17
|
+
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
18
|
+
|
19
|
+
import logging
|
20
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
21
|
+
|
22
|
+
import torch
|
23
|
+
from torch import nn
|
24
|
+
from transformers import Llama4TextConfig
|
25
|
+
|
26
|
+
from sglang.srt.distributed import (
|
27
|
+
get_tensor_model_parallel_world_size,
|
28
|
+
tensor_model_parallel_all_reduce,
|
29
|
+
)
|
30
|
+
from sglang.srt.layers.dp_attention import (
|
31
|
+
dp_gather_partial,
|
32
|
+
dp_scatter,
|
33
|
+
get_attention_dp_size,
|
34
|
+
get_attention_tp_rank,
|
35
|
+
get_attention_tp_size,
|
36
|
+
)
|
37
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
38
|
+
from sglang.srt.layers.linear import (
|
39
|
+
QKVParallelLinear,
|
40
|
+
ReplicatedLinear,
|
41
|
+
RowParallelLinear,
|
42
|
+
)
|
43
|
+
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
44
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
45
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
46
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
47
|
+
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
48
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
49
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
50
|
+
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
51
|
+
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
52
|
+
|
53
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
54
|
+
|
55
|
+
|
56
|
+
class Llama4MoE(nn.Module):
    """Llama 4 mixture-of-experts feed-forward block.

    Each token goes through an always-on shared expert and through the top-k
    routed experts picked by a sigmoid-gated router. The two outputs are summed
    and, when tensor parallelism is active, all-reduced once at the end.
    """

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token and weight them with a sigmoid.

        Args:
            hidden_states: Token activations; only their dtype is used here.
            gating_output: Router logits with one column per local expert.
            topk: Number of experts chosen for every token.
            renormalize: Part of the routing-function interface; unused here.

        Returns:
            ``(weights, expert_ids)``. ``weights`` has ``hidden_states.dtype`` and
            ``expert_ids`` is ``int32``.
        """
        top_logits, expert_ids = fast_topk(gating_output, topk, dim=-1)
        # The sigmoid is evaluated in float32, then cast back to the activation dtype.
        weights = torch.sigmoid(top_logits.float()).to(hidden_states.dtype)
        # Flatten-then-reshape round trip, preserved as-is (values are unchanged).
        weights = weights.view(-1).reshape(weights.shape)
        return weights, expert_ids.to(torch.int32)

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok

        hidden = config.hidden_size
        expert_ffn_dim = config.intermediate_size

        # The router is never quantized (quant_config=None).
        self.router = ReplicatedLinear(
            hidden,
            config.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("router", prefix),
        )

        # Routed experts skip their own reduction; forward() reduces once.
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=self.top_k,
            hidden_size=hidden,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=expert_ffn_dim,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,
            prefix=add_prefix("experts", prefix),
        )

        self.shared_expert = LlamaMLP(
            hidden_size=hidden,
            intermediate_size=expert_ffn_dim,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=add_prefix("shared_expert", prefix),
            reduce_results=False,  # reduced together with the routed output in forward()
        )

    def forward(self, hidden_states):
        """Return the shared-expert output plus the routed-expert output."""
        # Router logits: one score per local expert for every token.
        router_logits, _ = self.router(hidden_states)
        shared_out = self.shared_expert(hidden_states)
        routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        combined = routed_out + shared_out

        if self.tp_size <= 1:
            return combined
        # Both branches were built with reduce_results=False, so results are
        # combined across tensor-parallel ranks here with a single all-reduce.
        return tensor_model_parallel_all_reduce(combined)
|
130
|
+
|
131
|
+
|
132
|
+
class Llama4Attention(nn.Module):
    """Llama 4 self-attention layer.

    If ``(layer_id + 1)`` is not a multiple of 4, the layer applies rotary
    position embeddings. It also applies, when ``config.use_qk_norm`` is set,
    an RMSNorm over every q/k head. The remaining ("NoPE") layers skip RoPE and
    may instead scale the queries by a position-dependent temperature
    (``config.attn_temperature_tuning``).
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        # 1 for RoPE layers, 0 for every 4th (NoPE) layer.
        self.use_rope = int((layer_id + 1) % 4 != 0)
        self.use_qk_norm = config.use_qk_norm and self.use_rope

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.floor_scale = config.floor_scale
        self.attn_scale = config.attn_scale
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads
        self.qk_norm = (
            RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
            )
            if self.use_qk_norm
            else None
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
        )
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type in ["llama", "llama4"]:
            is_neox_style = False

        # NOTE(review): rope_scaling is typed as a dict, so this comparison only
        # changes anything if a caller passes the literal string "default".
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=int(rope_theta),
                rope_scaling=rope_scaling if rope_scaling != "default" else None,
                is_neox_style=is_neox_style,
            )
            if self.use_rope
            else None
        )

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
            use_irope=self.use_rope,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Per-position temperature: log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1.

        A trailing dimension is added so the result broadcasts against ``q``.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
        return attn_scale.unsqueeze(-1)

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def _mul_attn_scale(self, positions, q):
        """Scale ``q`` by the per-position temperature, keeping ``q``'s dtype."""
        attn_scale = self._get_attn_scale(positions)
        return (q * attn_scale).to(q.dtype)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE or temperature tuning, attend, project out.

        The output is not reduced across attention-TP ranks
        (``o_proj`` uses ``reduce_results=False``); the caller handles that.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            # The rotary module is required (asserted) to rotate q/k in place on
            # these views, so the rotated values land directly in `qk`.
            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
            assert (q_out_unused is q_view) and (k_out_unused is k_view)
            del q_view, k_view, q_out_unused, k_out_unused

        if self.qk_norm is not None:
            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
            # Normalize every head independently. Keep the projection's dtype
            # instead of forcing bfloat16, so q/k stay consistent with v for
            # non-bf16 model dtypes. This is a no-op change for bf16 models.
            qk_dtype = qk.dtype
            qk = qk.reshape(-1, self.head_dim).contiguous()
            qk = self.qk_norm(qk).to(qk_dtype)
            qk = qk.reshape(-1, self.q_size + self.kv_size)

        q, k = qk.split([self.q_size, self.kv_size], dim=-1)

        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
        # the inference-time temperature tuning function is customized to not affect short context
        # while working at very long context
        # https://arxiv.org/abs/2501.19399
        if self.attn_temperature_tuning and not self.use_rope:
            q = self._mul_attn_scale(positions=positions, q=q)

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
|
284
|
+
|
285
|
+
|
286
|
+
class Llama4DecoderLayer(nn.Module):
    """One Llama 4 decoder block.

    Pre-norm self-attention is followed by a feed-forward network. The network
    is ``Llama4MoE`` on every ``config.interleave_moe_layer_step``-th layer and
    a dense ``LlamaMLP`` otherwise. The residual stream is threaded through the
    two RMSNorm calls.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        self.self_attn = Llama4Attention(
            config=config,
            layer_id=layer_id,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )

        ffn_prefix = add_prefix("feed_forward", prefix)
        uses_moe = (layer_id + 1) % config.interleave_moe_layer_step == 0
        if uses_moe:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=ffn_prefix,
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=ffn_prefix,
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the feed-forward network; return ``(hidden_states, residual)``."""
        if hidden_states.shape[0] != 0:
            hidden_states, residual = self._run_self_attention(
                positions, hidden_states, forward_batch, residual
            )
        else:
            # Empty batch on this rank: attention is skipped, but the
            # gather/reduce steps below still run.
            residual = hidden_states

        hidden_states, residual = self._to_ffn_input(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.feed_forward(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter.
        if self.dp_size != 1:
            hidden_states = self._scatter_ffn_output(hidden_states, forward_batch)

        return hidden_states, residual

    def _run_self_attention(self, positions, hidden_states, forward_batch, residual):
        """Apply the input RMSNorm (fusing in the residual if any) and attention."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        return attn_out, residual

    def _to_ffn_input(self, hidden_states, residual, forward_batch):
        """Reduce/gather the attention output and apply the post-attention norm."""
        if get_tensor_model_parallel_world_size() <= 1:
            return self.post_attention_layernorm(hidden_states, residual)

        if self.dp_size == 1:
            reduced = tensor_model_parallel_all_reduce(hidden_states)
            return self.post_attention_layernorm(reduced, residual)

        # DP attention: only attention-TP rank 0 folds the residual in (in place)
        # before the partial gather into the shared buffer. The residual is then
        # refreshed by scattering that buffer back into it.
        if self.attn_tp_rank == 0:
            hidden_states += residual
        gathered = forward_batch.gathered_buffer
        dp_gather_partial(gathered, hidden_states, forward_batch)
        dp_scatter(residual, gathered, forward_batch)
        return self.post_attention_layernorm(gathered), residual

    def _scatter_ffn_output(self, hidden_states, forward_batch):
        """Scatter the feed-forward output back to this rank's tokens.

        Important: forward_batch.gathered_buffer is reused both after the gather
        and after this scatter - be careful when touching either.
        """
        local_out = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
        dp_scatter(local_out, hidden_states, forward_batch)
        return local_out
|
398
|
+
|
399
|
+
|
400
|
+
class Llama4Model(nn.Module):
    """Llama 4 text decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
            # Tensor-parallel sharding of the embedding is turned off when DP
            # attention is enabled.
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Llama4DecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            prefix=add_prefix("layers", prefix),
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Indices of layers whose *input* hidden state is returned as an
        # auxiliary output; empty by default.
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; ignored when ``input_embeds`` is given.
            positions: Token positions forwarded to every layer.
            forward_batch: Batch metadata forwarded to every layer.
            input_embeds: Optional precomputed input embeddings.

        Returns:
            The final hidden states. If any layer index is listed in
            ``layers_to_capture``, returns ``(hidden_states, aux_hidden_states)``
            instead.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                # Before the first layer there is no residual yet; adding None
                # would raise, and the embedding alone is the layer input.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.norm(hidden_states, residual)

        if not aux_hidden_states:
            return hidden_states

        return hidden_states, aux_hidden_states
|
459
|
+
|
460
|
+
|
461
|
+
class Llama4ForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper for Llama 4 text models.

    Builds on ``LlamaForCausalLM``: ``_init_model`` is overridden to build a
    ``Llama4Model`` backbone (presumably invoked from the base class
    constructor - confirm), and ``get_input_embeddings`` exposes that
    backbone's token embedding.
    """

    # Fused projection name -> the checkpoint projections it packs together.
    packed_modules_mapping = dict(
        qkv_proj=["q_proj", "k_proj", "v_proj"],
        gate_up_proj=["gate_proj", "up_proj"],
    )

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(config, quant_config, prefix)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.embed_tokens

    def _init_model(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Construct the Llama 4 decoder backbone."""
        backbone = Llama4Model(config, quant_config=quant_config, prefix=prefix)
        return backbone
|
485
|
+
|
486
|
+
|
487
|
+
# Model classes exported by this module under the `EntryClass` name;
# presumably collected by SGLang's model registry - confirm against the registry code.
EntryClass = [Llama4ForCausalLM]
|
sglang/srt/models/minicpm.py
CHANGED
sglang/srt/models/minicpm3.py
CHANGED
@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
|
|
192
192
|
self.scaling,
|
193
193
|
num_kv_heads=self.num_local_heads,
|
194
194
|
layer_id=layer_id,
|
195
|
+
quant_config=quant_config,
|
195
196
|
prefix=add_prefix("attn", prefix),
|
196
197
|
)
|
197
198
|
|
@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
|
343
344
|
num_kv_heads=1,
|
344
345
|
layer_id=layer_id,
|
345
346
|
v_head_dim=self.kv_lora_rank,
|
347
|
+
quant_config=quant_config,
|
346
348
|
prefix=add_prefix("attn", prefix),
|
347
349
|
)
|
348
350
|
|
sglang/srt/models/mixtral.py
CHANGED
sglang/srt/models/mllama.py
CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
|
|
22
22
|
from sglang.srt.layers.linear import (
|
23
23
|
ColumnParallelLinear,
|
24
24
|
QKVParallelLinear,
|
25
|
+
ReplicatedLinear,
|
25
26
|
RowParallelLinear,
|
26
27
|
)
|
27
28
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
|
|
184
185
|
def __init__(
|
185
186
|
self,
|
186
187
|
config: config_mllama.MllamaVisionConfig,
|
188
|
+
quant_config: Optional[QuantizationConfig] = None,
|
187
189
|
is_gated: bool = False,
|
188
190
|
prefix: str = "",
|
189
191
|
):
|
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
|
|
199
201
|
self.num_attention_heads,
|
200
202
|
self.hidden_size,
|
201
203
|
use_qkv_parallel=True,
|
202
|
-
quant_config=
|
204
|
+
quant_config=quant_config,
|
203
205
|
dropout=0.0,
|
204
206
|
use_context_forward=False,
|
205
207
|
softmax_in_single_precision=False,
|
206
208
|
flatten_batch=False,
|
207
209
|
prefix=add_prefix("self_attn", prefix),
|
208
210
|
)
|
209
|
-
self.mlp = MllamaVisionMLP(
|
211
|
+
self.mlp = MllamaVisionMLP(
|
212
|
+
config, quant_config, prefix=add_prefix("mlp", prefix)
|
213
|
+
)
|
210
214
|
|
211
215
|
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
212
216
|
self.post_attention_layernorm = nn.LayerNorm(
|
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
|
|
244
248
|
def __init__(
|
245
249
|
self,
|
246
250
|
config: config_mllama.MllamaVisionConfig,
|
251
|
+
quant_config: Optional[QuantizationConfig] = None,
|
247
252
|
num_layers=32,
|
248
253
|
is_gated=False,
|
249
254
|
output_hidden_states=None,
|
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
|
|
254
259
|
self.layers = nn.ModuleList(
|
255
260
|
[
|
256
261
|
MllamaVisionEncoderLayer(
|
257
|
-
config,
|
262
|
+
config,
|
263
|
+
quant_config,
|
264
|
+
is_gated,
|
265
|
+
prefix=add_prefix(f"layers.{i}", prefix),
|
258
266
|
)
|
259
267
|
for i in range(num_layers)
|
260
268
|
]
|
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
|
|
283
291
|
|
284
292
|
|
285
293
|
class MllamaVisionModel(nn.Module):
|
286
|
-
def __init__(
|
294
|
+
def __init__(
|
295
|
+
self,
|
296
|
+
config: config_mllama.MllamaVisionConfig,
|
297
|
+
quant_config: Optional[QuantizationConfig] = None,
|
298
|
+
prefix: str = "",
|
299
|
+
):
|
287
300
|
super().__init__()
|
288
301
|
self.image_size = config.image_size
|
289
302
|
self.patch_size = config.patch_size
|
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
|
|
320
333
|
# encoders
|
321
334
|
self.transformer = MllamaVisionEncoder(
|
322
335
|
config,
|
336
|
+
quant_config,
|
323
337
|
config.num_hidden_layers,
|
324
338
|
is_gated=False,
|
325
339
|
output_hidden_states=config.intermediate_layers_indices,
|
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
|
|
327
341
|
)
|
328
342
|
self.global_transformer = MllamaVisionEncoder(
|
329
343
|
config,
|
344
|
+
quant_config,
|
330
345
|
config.num_global_layers,
|
331
346
|
is_gated=True,
|
332
347
|
prefix=add_prefix("global_transformer", prefix),
|
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
|
|
535
550
|
self.num_local_key_value_heads,
|
536
551
|
layer_id=layer_id,
|
537
552
|
is_cross_attention=True,
|
553
|
+
quant_config=quant_config,
|
538
554
|
prefix=add_prefix("attn", prefix),
|
539
555
|
)
|
540
556
|
|
@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
|
|
764
780
|
|
765
781
|
|
766
782
|
class MllamaForConditionalGeneration(nn.Module):
|
783
|
+
# BitandBytes specific attributes
|
784
|
+
default_bitsandbytes_target_modules = [
|
785
|
+
".gate_proj.",
|
786
|
+
".down_proj.",
|
787
|
+
".up_proj.",
|
788
|
+
".q_proj.",
|
789
|
+
".k_proj.",
|
790
|
+
".v_proj.",
|
791
|
+
".o_proj.",
|
792
|
+
]
|
793
|
+
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
794
|
+
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
795
|
+
bitsandbytes_stacked_params_mapping = {
|
796
|
+
# shard_name, weight_name, index
|
797
|
+
"q_proj": ("qkv_proj", 0),
|
798
|
+
"k_proj": ("qkv_proj", 1),
|
799
|
+
"v_proj": ("qkv_proj", 2),
|
800
|
+
"gate_proj": ("gate_up_proj", 0),
|
801
|
+
"up_proj": ("gate_up_proj", 1),
|
802
|
+
}
|
803
|
+
|
767
804
|
def __init__(
|
768
805
|
self,
|
769
806
|
config: config_mllama.MllamaConfig,
|
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|
771
808
|
prefix: str = "",
|
772
809
|
):
|
773
810
|
super().__init__()
|
811
|
+
self.quant_config = quant_config
|
774
812
|
self.vocab_size = config.text_config.vocab_size
|
775
813
|
self.hidden_size = config.text_config.hidden_size
|
776
814
|
self.max_num_tiles = config.vision_config.max_num_tiles
|
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
|
|
781
819
|
self.image_size = config.vision_config.image_size
|
782
820
|
|
783
821
|
self.vision_model = MllamaVisionModel(
|
784
|
-
config.vision_config,
|
822
|
+
config.vision_config,
|
823
|
+
quant_config=quant_config,
|
824
|
+
prefix=add_prefix("vision_model", prefix),
|
785
825
|
)
|
786
826
|
self.language_model = MllamaForCausalLM(
|
787
827
|
config.text_config,
|
788
828
|
quant_config=quant_config,
|
789
829
|
prefix=add_prefix("language_model", prefix),
|
790
830
|
)
|
791
|
-
self.multi_modal_projector =
|
831
|
+
self.multi_modal_projector = ReplicatedLinear(
|
792
832
|
config.vision_config.vision_output_dim,
|
793
833
|
config.text_config.hidden_size,
|
794
834
|
bias=True,
|
835
|
+
quant_config=quant_config,
|
836
|
+
prefix="multi_modal_projector",
|
795
837
|
)
|
796
838
|
self.logits_processor = LogitsProcessor(config.text_config)
|
797
839
|
self.capture_mode = False
|
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
|
|
958
1000
|
cross_attention_states = self.vision_model(
|
959
1001
|
batched_images, batched_ar_ids, batched_ar_mask
|
960
1002
|
)
|
961
|
-
cross_attention_states = self.multi_modal_projector(
|
1003
|
+
cross_attention_states, _ = self.multi_modal_projector(
|
1004
|
+
cross_attention_states
|
1005
|
+
)
|
962
1006
|
|
963
1007
|
bs, _, _, _, image_token_dim = cross_attention_states.shape
|
964
1008
|
cross_attention_states = cross_attention_states.view(
|
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
|
|
1012
1056
|
if "vision_model" in name:
|
1013
1057
|
# adapt to VisionAttention
|
1014
1058
|
name = name.replace("self_attn.o_proj", "self_attn.proj")
|
1015
|
-
|
1016
1059
|
param = params_dict.pop(name)
|
1017
1060
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
1018
1061
|
weight_loader(param, loaded_weight)
|