sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the content of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/dbrx.py
ADDED
@@ -0,0 +1,406 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
+# coding=utf-8
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.transformers_utils.configs.dbrx import DbrxConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class DbrxRouter(nn.Module):
+    """A Router implementation for DBRX that returns logits for each expert
+    per token.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.d_model = config.d_model
+        self.layer = ReplicatedLinear(
+            self.d_model,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            quant_config=None,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.layer(hidden_states)
+        return router_logits
+
+
+class DbrxExperts(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.top_k = config.ffn_config.moe_top_k
+        self.d_model = config.d_model
+        self.intermediate_size = config.ffn_config.ffn_hidden_size // self.tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+        self.ws = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.d_model,
+                device="cuda",
+                dtype=self.params_dtype,
+            )
+        )
+        self.w2s = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.d_model,
+                self.intermediate_size,
+                device="cuda",
+                dtype=self.params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.ws,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2s,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str
+    ):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        # DBRX uses GLU for each experts.
+        # GLU has 3 linear layers: w1, v1 and w2.
+        if weight_name.endswith("w1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+        if weight_name.endswith("v1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[:, shard, :]
+        if weight_name.endswith("w2"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            ).transpose(1, 2)
+            param_data[:] = loaded_weight[:, :, shard]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.d_model)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.router(hidden_states)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.ws,
+            self.w2s,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)
+
+
+class DbrxAttention(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
+        self.total_num_kv_heads = config.attn_config.kv_n_heads
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.rope_theta = config.attn_config.rope_theta
+        self.max_position = config.max_seq_len
+
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
+            self.d_model,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_world_size
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        hidden_states, _ = self.out_proj(attn_output)
+        return hidden_states
+
+
+class DbrxFusedNormAttention(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
+        self.norm_1 = nn.LayerNorm(self.d_model)
+        self.norm_2 = nn.LayerNorm(self.d_model)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + x
+        residual = hidden_states
+        hidden_states = self.norm_2(hidden_states)
+        return hidden_states, residual
+
+
+class DbrxBlock(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
+        self.ffn = DbrxExperts(config, quant_config=quant_config)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states, residual = self.norm_attn_norm(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = self.ffn(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class DbrxModel(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
+        self.blocks = nn.ModuleList(
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
+        )
+        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
+        for module in self.modules():
+            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+                # Remove the bias term in Linear and LayerNorm.
+                module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.wte(input_ids)
+        else:
+            hidden_states = input_embeds
+        for i in range(len(self.blocks)):
+            block = self.blocks[i]
+            hidden_states = block(position_ids, hidden_states, input_metadata)
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class DbrxForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.unpadded_vocab_size = config.vocab_size
+        self.transformer = DbrxModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.d_model,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        expert_params_mapping = [
+            (
+                "ws" if weight_name in ["w1", "v1"] else "w2s",
+                f"experts.mlp.{weight_name}",
+            )
+            for weight_name in ["w1", "v1", "w2"]
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = DbrxForCausalLM
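The `DbrxExperts` docstring in the file above notes that each expert's weights are sharded across tensor-parallel ranks before the fused MoE kernel runs. As a reading aid, here is a minimal, CPU-only sketch of the same reshape-and-slice arithmetic that `weight_loader` applies when filling the fused `ws` parameter; the helper name `shard_glu_expert_weights` and the toy shapes are illustrative only and not part of the package.

```python
import torch


def shard_glu_expert_weights(
    loaded_w1: torch.Tensor,
    loaded_v1: torch.Tensor,
    tp_rank: int,
    tp_size: int,
    d_model: int,
    ffn_hidden: int,
    num_experts: int,
) -> torch.Tensor:
    """Build one rank's slice of a fused `ws`-style parameter:
    shape (num_experts, 2 * ffn_hidden // tp_size, d_model), with this rank's
    w1 rows in the first half of dim 1 and v1 rows in the second half,
    mirroring the logic in DbrxExperts.weight_loader."""
    shard_size = ffn_hidden // tp_size
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    # Checkpoint tensors arrive flattened as (num_experts * ffn_hidden, d_model).
    w1 = loaded_w1.reshape(num_experts, ffn_hidden, d_model)
    v1 = loaded_v1.reshape(num_experts, ffn_hidden, d_model)
    ws = torch.empty(num_experts, 2 * shard_size, d_model)
    ws[:, :shard_size, :] = w1[:, shard, :]
    ws[:, shard_size:, :] = v1[:, shard, :]
    return ws


# Toy shapes: 2 experts, d_model=4, ffn_hidden=6, tensor-parallel size 2.
w1 = torch.randn(2 * 6, 4)
v1 = torch.randn(2 * 6, 4)
ws_rank0 = shard_glu_expert_weights(w1, v1, 0, 2, d_model=4, ffn_hidden=6, num_experts=2)
assert ws_rank0.shape == (2, 6, 4)  # each rank holds 6 // 2 = 3 rows of w1 and 3 of v1 per expert
```

In the real module the same slice is written in place into the preallocated `ws`/`w2s` parameters on the GPU, and the per-rank partial outputs are combined afterward with `tensor_model_parallel_all_reduce`.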
sglang/srt/models/gemma.py
CHANGED
@@ -1,32 +1,28 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
-from vllm.
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.
-
-
-from
-
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class GemmaMLP(nn.Module):
@@ -34,17 +30,20 @@ class GemmaMLP(nn.Module):
         self,
         hidden_size: int,
         intermediate_size: int,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
@@ -65,7 +64,7 @@ class GemmaAttention(nn.Module):
         layer_id: int = 0,
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -95,13 +94,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -138,7 +137,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,12 +149,12 @@ class GemmaDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -191,7 +190,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -202,7 +201,7 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i,
+                GemmaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -263,14 +262,15 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
        self,
         config: PretrainedConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
         self.config = config
-        self.
-        self.model = GemmaModel(config,
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -286,13 +286,7 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -303,9 +297,7 @@ class GemmaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -318,6 +310,10 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue