sglang 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_serving.py +243 -25
- sglang/global_config.py +3 -2
- sglang/lang/interpreter.py +1 -0
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/radix_attention.py +38 -49
- sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
- sglang/srt/managers/controller/infer_batch.py +51 -22
- sglang/srt/managers/controller/model_runner.py +7 -4
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +9 -11
- sglang/srt/memory_pool.py +13 -5
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/llama2.py +19 -10
- sglang/srt/server.py +20 -1
- sglang/srt/server_args.py +12 -6
- sglang/srt/utils.py +49 -0
- {sglang-0.1.22.dist-info → sglang-0.1.24.dist-info}/METADATA +9 -5
- {sglang-0.1.22.dist-info → sglang-0.1.24.dist-info}/RECORD +24 -22
- {sglang-0.1.22.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- {sglang-0.1.22.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.22.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek.py (new file)
@@ -0,0 +1,430 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata


class DeepseekMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        self.experts = nn.ModuleList(
            [
                DeepseekMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                )
                for idx in range(self.n_routed_experts)
            ]
        )
        self.pack_params()

        self.gate = ReplicatedLinear(
            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
        )

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def pack_params(self):
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.w1,
            self.w2,
            router_logits,
            self.top_k,
            renormalize=self.config.norm_topk_prob,
            inplace=True,
        )

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        if (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                DeepseekDecoderLayer(
                    config, layer_id, cache_config, quant_config=quant_config
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = DeepseekForCausalLM
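Editor's note: the snippet below is not part of the package; it is a minimal standalone sketch (with made-up shapes) of the aliasing trick that DeepseekMoE.pack_params relies on. Per-expert weights are flattened into one contiguous buffer, the expert parameters are re-pointed at views into that buffer, and the buffer itself is viewed as (num_experts, out, in) so fused_moe can index experts without copying.

import torch

num_experts, out_dim, in_dim = 4, 8, 16
w1 = [torch.nn.Parameter(torch.randn(out_dim, in_dim)) for _ in range(num_experts)]

flat = torch._utils._flatten_dense_tensors(w1)           # one contiguous 1-D buffer
views = torch._utils._unflatten_dense_tensors(flat, w1)  # per-expert views into it
for data, param in zip(views, w1):
    param.data = data                                     # params now alias the buffer

stacked = flat.view(num_experts, out_dim, in_dim)         # (E, out, in) for fused_moe

# Writes through an expert parameter are visible in the stacked tensor (shared storage).
w1[2].data.fill_(1.0)
assert torch.all(stacked[2] == 1.0)

Because the parameters and the stacked tensor share storage, the weight loader can keep writing into the per-expert parameters while the fused MoE kernel reads the stacked view.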
sglang/srt/models/gpt_bigcode.py (new file)
@@ -0,0 +1,282 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata


class GPTBigCodeAttention(nn.Module):

    def __init__(
        self,
        layer_id: int,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            scaling=self.scale,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim,
                self.kv_dim,
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(
            config.activation_function, quant_config, intermediate_size
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):

    def __init__(
        self,
        layer_id: int,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states, input_metadata=input_metadata
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(
            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [
                GPTBigCodeBlock(i, config, cache_config, quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, input_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):
    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(
            config, cache_config, quant_config, lora_config
        )
        self.lm_head = self.transformer.wte
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                weight_loader(param, loaded_weight, "q")
                weight_loader(param, loaded_weight, "k")
                weight_loader(param, loaded_weight, "v")
            else:
                weight_loader(param, loaded_weight)


EntryClass = GPTBigCodeForCausalLM
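Editor's note: the snippet below is not part of the package; it is a minimal standalone sketch of the QKV split performed in GPTBigCodeAttention.forward under multi-query attention (multi_query=True): the fused c_attn output carries all query heads for the rank but only one shared K head and one shared V head, so the split sizes are [hidden_size // tp_size, head_dim, head_dim]. The numbers are made up for illustration and tp_size is assumed to be 1.

import torch

hidden_size, num_heads, tp_size = 768, 12, 1
head_dim = hidden_size // num_heads           # 64
num_kv_heads = 1                              # multi-query: one shared KV head
kv_dim = head_dim * num_kv_heads              # 64

tokens = 5
# Fused projection output: queries for all heads plus one K and one V head.
qkv = torch.randn(tokens, hidden_size // tp_size + 2 * kv_dim)
q, k, v = qkv.split([hidden_size // tp_size, kv_dim, kv_dim], dim=-1)

print(q.shape, k.shape, v.shape)
# torch.Size([5, 768]) torch.Size([5, 64]) torch.Size([5, 64])

With multi_query=False the same split would instead yield K and V of width hidden_size // tp_size, matching standard multi-head attention.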