sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3_causal.py (new file)
@@ -0,0 +1,694 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Iterable, Optional, Set, Tuple

import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    ROPE_INIT_FUNCTIONS,
    AutoModel,
    Gemma3TextConfig,
    PretrainedConfig,
    PreTrainedModel,
)

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers


# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_attention_sliding_window_size(config):
    return config.sliding_window - 1


# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the layer index from a prefix string."""
    parts = prefix.split(".")
    for part in parts:
        if part.startswith("layers."):
            layer_str = part.split(".")[-1]
            try:
                return int(layer_str)
            except ValueError:
                continue
    return -1


class Gemma3MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma3Attention(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: Gemma3TextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0

        hidden_size = config.hidden_size

        head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim

        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine if layer uses sliding window based on pattern
        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)

        # Initialize the rotary embedding.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            # FIXME(mick): idk why vllm does this
            # self.sliding_window = config.interleaved_sliding_window
            self.sliding_window = get_attention_sliding_window_size(config)
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
            # Module must also define `get_attention_sliding_window_size` to correctly initialize
            # attention backend in `ForwardBatch`.
            sliding_window_size=self.sliding_window,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3 adds normalization for q and k
        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        # [s, h * head_dim]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # [s, h, head_dim]
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        # -> [h, s, head_dim]
        q = q.transpose(0, 1).unsqueeze(0)
        q = self.q_norm(q)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        # -> [h, s, head_dim]
        k = k.transpose(0, 1).unsqueeze(0)
        k = self.k_norm(k)

        # q, k = self.rotary_emb(positions, q, k)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # [b, h, s, head_dim] -> [b, s, h, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Gemma3DecoderLayer(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        return outputs


class Gemma3RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class Gemma3TextModel(PreTrainedModel):
    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )

        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
        config = copy.deepcopy(config)
        config.rope_theta = config.rope_local_base_freq
        config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        if positions.dim() == 1:
            positions = einops.rearrange(positions, "s -> 1 s")

        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return hidden_states


class Gemma3ForCausalLM(PreTrainedModel):
    config_class = Gemma3TextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> LogitsProcessor:
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )

        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                # if param_name in name:
                #     print(f"{param_name} is already in {name}")
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        # unloaded_params = params_dict.keys() - loaded_params
        # if unloaded_params:
        #     logger.warning(
        #         "Some weights are not initialized from checkpoints: %s", unloaded_params
        #     )
        return loaded_params


EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
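For context, the largest addition in this diff, sglang/srt/models/gemma3_causal.py, registers Gemma3ForCausalLM as an EntryClass, so Gemma 3 text checkpoints load through the usual sglang model registry. The snippet below is a minimal sketch of exercising it via sglang's offline Engine API; the checkpoint id google/gemma-3-1b-it and the sampling parameters are illustrative assumptions, not part of this diff.

# Minimal usage sketch (not part of the diff). Assumes a Gemma 3 text checkpoint
# is available; the model path and sampling parameters are illustrative.
import sglang as sgl

if __name__ == "__main__":
    # Launch the offline engine on the assumed checkpoint.
    llm = sgl.Engine(model_path="google/gemma-3-1b-it")
    prompts = ["The capital of France is"]
    sampling_params = {"temperature": 0.0, "max_new_tokens": 32}
    # Engine.generate returns one dict per prompt with the generated text.
    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(prompt, "->", output["text"])
    llm.shutdown()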