sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +14 -0
- sglang/backend/anthropic.py +18 -12
- sglang/backend/base_backend.py +6 -0
- sglang/backend/openai.py +41 -12
- sglang/backend/runtime_endpoint.py +57 -6
- sglang/lang/chat_template.py +47 -26
- sglang/lang/interpreter.py +15 -2
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +23 -1
- sglang/srt/constrained/fsm_cache.py +14 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -1
- sglang/srt/layers/extend_attention.py +7 -6
- sglang/srt/layers/radix_attention.py +2 -10
- sglang/srt/layers/token_attention.py +12 -4
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/router/infer_batch.py +6 -2
- sglang/srt/managers/router/model_rpc.py +45 -32
- sglang/srt/managers/router/model_runner.py +40 -25
- sglang/srt/managers/tokenizer_manager.py +2 -0
- sglang/srt/model_config.py +12 -5
- sglang/srt/models/gemma.py +340 -0
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llava.py +2 -4
- sglang/srt/models/mixtral.py +5 -5
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +5 -5
- sglang/srt/models/stablelm.py +293 -0
- sglang/srt/server.py +111 -47
- sglang/srt/server_args.py +44 -9
- sglang/srt/utils.py +1 -0
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +15 -12
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
- sglang-0.1.14.dist-info/RECORD +64 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
- sglang/srt/models/gpt_neox.py +0 -274
- sglang-0.1.12.dist-info/RECORD +0 -63
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py
ADDED
@@ -0,0 +1,340 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+"""Inference-only Gemma model compatible with HuggingFace weights."""
+from typing import Optional, Tuple
+
+import torch
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import LoRAConfig
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
+
+class GemmaMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+        )
+        self.act_fn = GeluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class GemmaAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        layer_id: int = 0,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GemmaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = GemmaAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=config.head_dim,
+            layer_id=layer_id,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_theta=config.rope_theta,
+            linear_method=linear_method,
+        )
+        self.mlp = GemmaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class GemmaModel(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                GemmaDecoderLayer(config, i, linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        # Normalize the embedding by sqrt(hidden_size)
+        hidden_states *= self.config.hidden_size**0.5
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        del lora_config  # Unused.
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = GemmaModel(config, linear_method)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+        )
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params = set()
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}"
+            )
+
+
+EntryClass = GemmaForCausalLM
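A note on the load_weights hunk above: HuggingFace Gemma checkpoints ship separate q_proj/k_proj/v_proj and gate_proj/up_proj tensors, while the model builds fused qkv_proj and gate_up_proj layers, so the loader renames each checkpoint key and hands the matching shard id to the fused parameter's weight_loader; RMSNorm weights additionally get 1.0 added because GemmaRMSNorm scales by (1 + weight). The snippet below is a minimal sketch of just the key remapping, with a hypothetical remap_checkpoint_key helper and no torch dependency, not the actual loader.

# Minimal sketch (hypothetical helper, no torch): how stacked_params_mapping in
# GemmaForCausalLM.load_weights rewrites per-shard checkpoint keys onto the fused
# parameters created by QKVParallelLinear and MergedColumnParallelLinear.
STACKED_PARAMS_MAPPING = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap_checkpoint_key(name: str):
    """Return (target param name, shard id); shard id is None for 1:1 parameters."""
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None

# q_proj lands in the "q" slice of the fused qkv_proj parameter.
print(remap_checkpoint_key("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap_checkpoint_key("model.layers.0.mlp.up_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 1)
# RMSNorm keys pass through unchanged; the real loader then adds 1.0 to them
# because GemmaRMSNorm scales its output by (1 + weight).
print(remap_checkpoint_key("model.layers.0.input_layernorm.weight"))
# -> ('model.layers.0.input_layernorm.weight', None)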
sglang/srt/models/llama2.py
CHANGED
@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states =
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
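The llama2.py hunks above (and the matching changes to mixtral.py and qwen2.py below) replace the old embedding-skip flag with an optional input_embeds tensor: callers that already hold embeddings, such as the LLaVA path, hand them in directly and the token-embedding lookup is skipped. A minimal sketch of that contract, using a plain nn.Embedding as a stand-in for VocabParallelEmbedding and hypothetical shapes, follows.

# Minimal sketch of the new forward contract (stand-in embedding, hypothetical shapes):
# pass token ids and let the model embed them, or pass precomputed embeddings
# through input_embeds and the lookup is skipped entirely.
import torch
from torch import nn

embed_tokens = nn.Embedding(32000, 4096)  # stand-in for VocabParallelEmbedding

def hidden_states_for(input_ids: torch.Tensor, input_embeds: torch.Tensor = None):
    if input_embeds is None:
        return embed_tokens(input_ids)  # text-only path: embed the ids
    return input_embeds                 # multimodal path: embeddings supplied by caller

ids = torch.randint(0, 32000, (1, 16))
from_ids = hidden_states_for(ids)                               # shape (1, 16, 4096)
from_embeds = hidden_states_for(ids, torch.randn(1, 16, 4096))  # returned unchanged
assert from_ids.shape == from_embeds.shape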
sglang/srt/models/llava.py
CHANGED
@@ -230,12 +230,10 @@ class LlavaLlamaForCausalLM(nn.Module):
                 pt += 1
 
             return self.language_model(
-
+                input_ids, positions, input_metadata, input_embeds=input_embeds
             )
         elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.language_model(
-                input_ids, positions, input_metadata, skip_embed=False
-            )
+            return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(
         self,
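llava.py is the caller side of the same change: in EXTEND (prefill) mode the merged text-and-image embeddings are passed through input_embeds, while DECODE steps only see new token ids and take the ordinary embedding path. A compact sketch of that dispatch, with a hypothetical ForwardMode enum and an opaque language_model callable, is shown here.

# Sketch of the forward-mode dispatch after this change (hypothetical enum, stub model).
from enum import Enum, auto

class ForwardMode(Enum):  # stand-in for sglang's ForwardMode
    EXTEND = auto()
    DECODE = auto()

def run_language_model(language_model, input_ids, positions, input_metadata, input_embeds):
    if input_metadata.forward_mode == ForwardMode.EXTEND:
        # Prefill: image features were already merged into input_embeds.
        return language_model(
            input_ids, positions, input_metadata, input_embeds=input_embeds
        )
    elif input_metadata.forward_mode == ForwardMode.DECODE:
        # Decode: plain token ids; the language model embeds them itself.
        return language_model(input_ids, positions, input_metadata)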
sglang/srt/models/mixtral.py
CHANGED
@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states =
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
sglang/srt/models/qwen.py
CHANGED
@@ -5,6 +5,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
+from transformers import PretrainedConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -25,7 +26,6 @@ from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
-from vllm.transformers_utils.configs.qwen import QWenConfig
 
 
 class QWenMLP(nn.Module):
@@ -130,7 +130,7 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -179,7 +179,7 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: PretrainedConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -216,7 +216,7 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: PretrainedConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, linear_method=linear_method)
sglang/srt/models/qwen2.py
CHANGED
@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states =
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
|