sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py
CHANGED
@@ -1,33 +1,30 @@
|
|
1
1
|
# Adapted from llama2.py
|
2
2
|
# Modify details for the adaptation of Qwen2 model.
|
3
3
|
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
4
|
-
from typing import Any, Dict,
|
4
|
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
5
5
|
|
6
6
|
import torch
|
7
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor
|
8
|
-
from sglang.srt.layers.radix_attention import RadixAttention
|
9
|
-
from sglang.srt.managers.router.model_runner import InputMetadata
|
10
7
|
from torch import nn
|
8
|
+
from vllm.config import CacheConfig
|
9
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
11
10
|
from vllm.model_executor.layers.activation import SiluAndMul
|
12
11
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
13
12
|
from vllm.model_executor.layers.linear import (
|
14
|
-
LinearMethodBase,
|
15
13
|
MergedColumnParallelLinear,
|
16
14
|
QKVParallelLinear,
|
17
15
|
RowParallelLinear,
|
18
16
|
)
|
17
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
19
18
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
20
19
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
21
20
|
ParallelLMHead,
|
22
21
|
VocabParallelEmbedding,
|
23
22
|
)
|
24
|
-
from vllm.model_executor.
|
25
|
-
|
26
|
-
|
27
|
-
from
|
28
|
-
|
29
|
-
hf_model_weights_iterator,
|
30
|
-
)
|
23
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
24
|
+
|
25
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
26
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
27
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
31
28
|
|
32
29
|
Qwen2Config = None
|
33
30
|
|
@@ -38,17 +35,20 @@ class Qwen2MLP(nn.Module):
|
|
38
35
|
hidden_size: int,
|
39
36
|
intermediate_size: int,
|
40
37
|
hidden_act: str,
|
41
|
-
|
38
|
+
quant_config: Optional[QuantizationConfig] = None,
|
42
39
|
) -> None:
|
43
40
|
super().__init__()
|
44
41
|
self.gate_up_proj = MergedColumnParallelLinear(
|
45
42
|
hidden_size,
|
46
43
|
[intermediate_size] * 2,
|
47
44
|
bias=False,
|
48
|
-
|
45
|
+
quant_config=quant_config,
|
49
46
|
)
|
50
47
|
self.down_proj = RowParallelLinear(
|
51
|
-
intermediate_size,
|
48
|
+
intermediate_size,
|
49
|
+
hidden_size,
|
50
|
+
bias=False,
|
51
|
+
quant_config=quant_config,
|
52
52
|
)
|
53
53
|
if hidden_act != "silu":
|
54
54
|
raise ValueError(
|
@@ -74,7 +74,7 @@ class Qwen2Attention(nn.Module):
|
|
74
74
|
rope_theta: float = 1000000,
|
75
75
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
76
76
|
max_position_embeddings: int = 32768,
|
77
|
-
|
77
|
+
quant_config: Optional[QuantizationConfig] = None,
|
78
78
|
) -> None:
|
79
79
|
super().__init__()
|
80
80
|
self.hidden_size = hidden_size
|
@@ -105,13 +105,13 @@ class Qwen2Attention(nn.Module):
|
|
105
105
|
self.total_num_heads,
|
106
106
|
self.total_num_kv_heads,
|
107
107
|
bias=True,
|
108
|
-
|
108
|
+
quant_config=quant_config,
|
109
109
|
)
|
110
110
|
self.o_proj = RowParallelLinear(
|
111
111
|
self.total_num_heads * self.head_dim,
|
112
112
|
hidden_size,
|
113
113
|
bias=False,
|
114
|
-
|
114
|
+
quant_config=quant_config,
|
115
115
|
)
|
116
116
|
|
117
117
|
self.rotary_emb = get_rope(
|
@@ -148,7 +148,7 @@ class Qwen2DecoderLayer(nn.Module):
|
|
148
148
|
self,
|
149
149
|
config: Qwen2Config,
|
150
150
|
layer_id: int = 0,
|
151
|
-
|
151
|
+
quant_config: Optional[QuantizationConfig] = None,
|
152
152
|
) -> None:
|
153
153
|
super().__init__()
|
154
154
|
self.hidden_size = config.hidden_size
|
@@ -163,13 +163,13 @@ class Qwen2DecoderLayer(nn.Module):
|
|
163
163
|
rope_theta=rope_theta,
|
164
164
|
rope_scaling=rope_scaling,
|
165
165
|
max_position_embeddings=max_position_embeddings,
|
166
|
-
|
166
|
+
quant_config=quant_config,
|
167
167
|
)
|
168
168
|
self.mlp = Qwen2MLP(
|
169
169
|
hidden_size=self.hidden_size,
|
170
170
|
intermediate_size=config.intermediate_size,
|
171
171
|
hidden_act=config.hidden_act,
|
172
|
-
|
172
|
+
quant_config=quant_config,
|
173
173
|
)
|
174
174
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
175
175
|
self.post_attention_layernorm = RMSNorm(
|
@@ -205,7 +205,7 @@ class Qwen2Model(nn.Module):
|
|
205
205
|
def __init__(
|
206
206
|
self,
|
207
207
|
config: Qwen2Config,
|
208
|
-
|
208
|
+
quant_config: Optional[QuantizationConfig] = None,
|
209
209
|
) -> None:
|
210
210
|
super().__init__()
|
211
211
|
self.config = config
|
@@ -217,7 +217,7 @@ class Qwen2Model(nn.Module):
|
|
217
217
|
)
|
218
218
|
self.layers = nn.ModuleList(
|
219
219
|
[
|
220
|
-
Qwen2DecoderLayer(config, i,
|
220
|
+
Qwen2DecoderLayer(config, i, quant_config=quant_config)
|
221
221
|
for i in range(config.num_hidden_layers)
|
222
222
|
]
|
223
223
|
)
|
@@ -251,12 +251,13 @@ class Qwen2ForCausalLM(nn.Module):
|
|
251
251
|
def __init__(
|
252
252
|
self,
|
253
253
|
config: Qwen2Config,
|
254
|
-
|
254
|
+
quant_config: Optional[QuantizationConfig] = None,
|
255
|
+
cache_config: Optional[CacheConfig] = None,
|
255
256
|
) -> None:
|
256
257
|
super().__init__()
|
257
258
|
self.config = config
|
258
|
-
self.
|
259
|
-
self.model = Qwen2Model(config,
|
259
|
+
self.quant_config = quant_config
|
260
|
+
self.model = Qwen2Model(config, quant_config=quant_config)
|
260
261
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
261
262
|
self.logits_processor = LogitsProcessor(config)
|
262
263
|
|
@@ -272,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|
272
273
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
273
274
|
)
|
274
275
|
|
275
|
-
def load_weights(
|
276
|
-
self,
|
277
|
-
model_name_or_path: str,
|
278
|
-
cache_dir: Optional[str] = None,
|
279
|
-
load_format: str = "auto",
|
280
|
-
revision: Optional[str] = None,
|
281
|
-
):
|
276
|
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
282
277
|
stacked_params_mapping = [
|
283
278
|
# (param_name, shard_name, shard_id)
|
284
279
|
("qkv_proj", "q_proj", "q"),
|
@@ -288,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|
288
283
|
("gate_up_proj", "up_proj", 1),
|
289
284
|
]
|
290
285
|
params_dict = dict(self.named_parameters())
|
291
|
-
for name, loaded_weight in
|
292
|
-
model_name_or_path, cache_dir, load_format, revision
|
293
|
-
):
|
286
|
+
for name, loaded_weight in weights:
|
294
287
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
295
288
|
continue
|
296
289
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
@@ -304,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|
304
297
|
# Skip loading extra bias for GPTQ models.
|
305
298
|
if name.endswith(".bias") and name not in params_dict:
|
306
299
|
continue
|
300
|
+
if name.startswith("model.vision_tower") and name not in params_dict:
|
301
|
+
continue
|
307
302
|
param = params_dict[name]
|
308
303
|
weight_loader = param.weight_loader
|
309
304
|
weight_loader(param, loaded_weight, shard_id)
|
@@ -312,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|
312
307
|
# Skip loading extra bias for GPTQ models.
|
313
308
|
if name.endswith(".bias") and name not in params_dict:
|
314
309
|
continue
|
310
|
+
if name.startswith("model.vision_tower") and name not in params_dict:
|
311
|
+
continue
|
315
312
|
param = params_dict[name]
|
316
313
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
317
314
|
weight_loader(param, loaded_weight)
|
@@ -0,0 +1,473 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Adapted from
|
3
|
+
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
|
4
|
+
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
5
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.nn.functional as F
|
9
|
+
from torch import nn
|
10
|
+
from transformers import PretrainedConfig
|
11
|
+
from vllm.config import CacheConfig
|
12
|
+
from vllm.distributed import (
|
13
|
+
get_tensor_model_parallel_world_size,
|
14
|
+
tensor_model_parallel_all_reduce,
|
15
|
+
)
|
16
|
+
from vllm.model_executor.layers.activation import SiluAndMul
|
17
|
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
18
|
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
19
|
+
from vllm.model_executor.layers.linear import (
|
20
|
+
MergedColumnParallelLinear,
|
21
|
+
QKVParallelLinear,
|
22
|
+
ReplicatedLinear,
|
23
|
+
RowParallelLinear,
|
24
|
+
)
|
25
|
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
26
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
27
|
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
28
|
+
from vllm.model_executor.layers.sampler import Sampler
|
29
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
30
|
+
ParallelLMHead,
|
31
|
+
VocabParallelEmbedding,
|
32
|
+
)
|
33
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
34
|
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
35
|
+
from vllm.sequence import IntermediateTensors, SamplerOutput
|
36
|
+
|
37
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
38
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
39
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
40
|
+
|
41
|
+
|
42
|
+
class Qwen2MoeMLP(nn.Module):
|
43
|
+
def __init__(
|
44
|
+
self,
|
45
|
+
hidden_size: int,
|
46
|
+
intermediate_size: int,
|
47
|
+
hidden_act: str,
|
48
|
+
quant_config: Optional[QuantizationConfig] = None,
|
49
|
+
reduce_results: bool = True,
|
50
|
+
) -> None:
|
51
|
+
super().__init__()
|
52
|
+
self.gate_up_proj = MergedColumnParallelLinear(
|
53
|
+
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
54
|
+
)
|
55
|
+
self.down_proj = RowParallelLinear(
|
56
|
+
intermediate_size,
|
57
|
+
hidden_size,
|
58
|
+
bias=False,
|
59
|
+
quant_config=quant_config,
|
60
|
+
reduce_results=reduce_results,
|
61
|
+
)
|
62
|
+
if hidden_act != "silu":
|
63
|
+
raise ValueError(
|
64
|
+
f"Unsupported activation: {hidden_act}. "
|
65
|
+
"Only silu is supported for now."
|
66
|
+
)
|
67
|
+
self.act_fn = SiluAndMul()
|
68
|
+
|
69
|
+
def forward(self, x):
|
70
|
+
gate_up, _ = self.gate_up_proj(x)
|
71
|
+
x = self.act_fn(gate_up)
|
72
|
+
x, _ = self.down_proj(x)
|
73
|
+
return x
|
74
|
+
|
75
|
+
|
76
|
+
class Qwen2MoeSparseMoeBlock(nn.Module):
|
77
|
+
def __init__(
|
78
|
+
self,
|
79
|
+
config: PretrainedConfig,
|
80
|
+
quant_config: Optional[QuantizationConfig] = None,
|
81
|
+
):
|
82
|
+
super().__init__()
|
83
|
+
self.tp_size = get_tensor_model_parallel_world_size()
|
84
|
+
|
85
|
+
if self.tp_size > config.num_experts:
|
86
|
+
raise ValueError(
|
87
|
+
f"Tensor parallel size {self.tp_size} is greater than "
|
88
|
+
f"the number of experts {config.num_experts}."
|
89
|
+
)
|
90
|
+
|
91
|
+
self.experts = FusedMoE(
|
92
|
+
num_experts=config.num_experts,
|
93
|
+
top_k=config.num_experts_per_tok,
|
94
|
+
hidden_size=config.hidden_size,
|
95
|
+
intermediate_size=config.moe_intermediate_size,
|
96
|
+
reduce_results=False,
|
97
|
+
renormalize=config.norm_topk_prob,
|
98
|
+
quant_config=quant_config,
|
99
|
+
)
|
100
|
+
|
101
|
+
self.gate = ReplicatedLinear(
|
102
|
+
config.hidden_size, config.num_experts, bias=False, quant_config=None
|
103
|
+
)
|
104
|
+
if config.shared_expert_intermediate_size > 0:
|
105
|
+
self.shared_expert = Qwen2MoeMLP(
|
106
|
+
hidden_size=config.hidden_size,
|
107
|
+
intermediate_size=config.shared_expert_intermediate_size,
|
108
|
+
hidden_act=config.hidden_act,
|
109
|
+
quant_config=quant_config,
|
110
|
+
reduce_results=False,
|
111
|
+
)
|
112
|
+
else:
|
113
|
+
self.shared_expert = None
|
114
|
+
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
115
|
+
|
116
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
117
|
+
num_tokens, hidden_dim = hidden_states.shape
|
118
|
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
119
|
+
shared_output = None
|
120
|
+
if self.shared_expert is not None:
|
121
|
+
shared_output = self.shared_expert(hidden_states)
|
122
|
+
if self.shared_expert_gate is not None:
|
123
|
+
shared_output = (
|
124
|
+
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
|
125
|
+
)
|
126
|
+
|
127
|
+
# router_logits: (num_tokens, n_experts)
|
128
|
+
router_logits, _ = self.gate(hidden_states)
|
129
|
+
final_hidden_states = self.experts(
|
130
|
+
hidden_states=hidden_states, router_logits=router_logits
|
131
|
+
)
|
132
|
+
if shared_output is not None:
|
133
|
+
final_hidden_states = final_hidden_states + shared_output
|
134
|
+
if self.tp_size > 1:
|
135
|
+
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
136
|
+
|
137
|
+
return final_hidden_states.view(num_tokens, hidden_dim)
|
138
|
+
|
139
|
+
|
140
|
+
class Qwen2MoeAttention(nn.Module):
|
141
|
+
def __init__(
|
142
|
+
self,
|
143
|
+
hidden_size: int,
|
144
|
+
num_heads: int,
|
145
|
+
num_kv_heads: int,
|
146
|
+
layer_id: int = 0,
|
147
|
+
rope_theta: float = 10000,
|
148
|
+
rope_scaling: Optional[Dict[str, Any]] = None,
|
149
|
+
max_position_embeddings: int = 8192,
|
150
|
+
cache_config: Optional[CacheConfig] = None,
|
151
|
+
quant_config: Optional[QuantizationConfig] = None,
|
152
|
+
) -> None:
|
153
|
+
super().__init__()
|
154
|
+
self.hidden_size = hidden_size
|
155
|
+
tp_size = get_tensor_model_parallel_world_size()
|
156
|
+
self.total_num_heads = num_heads
|
157
|
+
assert self.total_num_heads % tp_size == 0
|
158
|
+
self.num_heads = self.total_num_heads // tp_size
|
159
|
+
self.total_num_kv_heads = num_kv_heads
|
160
|
+
if self.total_num_kv_heads >= tp_size:
|
161
|
+
# Number of KV heads is greater than TP size, so we partition
|
162
|
+
# the KV heads across multiple tensor parallel GPUs.
|
163
|
+
assert self.total_num_kv_heads % tp_size == 0
|
164
|
+
else:
|
165
|
+
# Number of KV heads is less than TP size, so we replicate
|
166
|
+
# the KV heads across multiple tensor parallel GPUs.
|
167
|
+
assert tp_size % self.total_num_kv_heads == 0
|
168
|
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
169
|
+
self.head_dim = hidden_size // self.total_num_heads
|
170
|
+
self.q_size = self.num_heads * self.head_dim
|
171
|
+
self.kv_size = self.num_kv_heads * self.head_dim
|
172
|
+
self.scaling = self.head_dim**-0.5
|
173
|
+
self.rope_theta = rope_theta
|
174
|
+
self.max_position_embeddings = max_position_embeddings
|
175
|
+
|
176
|
+
self.qkv_proj = QKVParallelLinear(
|
177
|
+
hidden_size,
|
178
|
+
self.head_dim,
|
179
|
+
self.total_num_heads,
|
180
|
+
self.total_num_kv_heads,
|
181
|
+
bias=True,
|
182
|
+
quant_config=quant_config,
|
183
|
+
)
|
184
|
+
|
185
|
+
self.o_proj = RowParallelLinear(
|
186
|
+
self.total_num_heads * self.head_dim,
|
187
|
+
hidden_size,
|
188
|
+
bias=False,
|
189
|
+
quant_config=quant_config,
|
190
|
+
)
|
191
|
+
|
192
|
+
self.rotary_emb = get_rope(
|
193
|
+
self.head_dim,
|
194
|
+
rotary_dim=self.head_dim,
|
195
|
+
max_position=max_position_embeddings,
|
196
|
+
base=rope_theta,
|
197
|
+
rope_scaling=rope_scaling,
|
198
|
+
)
|
199
|
+
self.attn = RadixAttention(
|
200
|
+
self.num_heads,
|
201
|
+
self.head_dim,
|
202
|
+
self.scaling,
|
203
|
+
num_kv_heads=self.num_kv_heads,
|
204
|
+
layer_id=layer_id,
|
205
|
+
)
|
206
|
+
|
207
|
+
def forward(
|
208
|
+
self,
|
209
|
+
positions: torch.Tensor,
|
210
|
+
hidden_states: torch.Tensor,
|
211
|
+
input_metadata: InputMetadata,
|
212
|
+
) -> torch.Tensor:
|
213
|
+
qkv, _ = self.qkv_proj(hidden_states)
|
214
|
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
215
|
+
q, k = self.rotary_emb(positions, q, k)
|
216
|
+
attn_output = self.attn(q, k, v, input_metadata)
|
217
|
+
output, _ = self.o_proj(attn_output)
|
218
|
+
return output
|
219
|
+
|
220
|
+
|
221
|
+
class Qwen2MoeDecoderLayer(nn.Module):
|
222
|
+
def __init__(
|
223
|
+
self,
|
224
|
+
config: PretrainedConfig,
|
225
|
+
layer_id: int,
|
226
|
+
cache_config: Optional[CacheConfig] = None,
|
227
|
+
quant_config: Optional[QuantizationConfig] = None,
|
228
|
+
) -> None:
|
229
|
+
super().__init__()
|
230
|
+
self.hidden_size = config.hidden_size
|
231
|
+
rope_theta = getattr(config, "rope_theta", 10000)
|
232
|
+
rope_scaling = getattr(config, "rope_scaling", None)
|
233
|
+
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
234
|
+
self.self_attn = Qwen2MoeAttention(
|
235
|
+
hidden_size=self.hidden_size,
|
236
|
+
num_heads=config.num_attention_heads,
|
237
|
+
num_kv_heads=config.num_key_value_heads,
|
238
|
+
layer_id=layer_id,
|
239
|
+
rope_theta=rope_theta,
|
240
|
+
rope_scaling=rope_scaling,
|
241
|
+
max_position_embeddings=max_position_embeddings,
|
242
|
+
cache_config=cache_config,
|
243
|
+
quant_config=quant_config,
|
244
|
+
)
|
245
|
+
|
246
|
+
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
247
|
+
# `mlp_only_layers` in the config.
|
248
|
+
mlp_only_layers = (
|
249
|
+
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
250
|
+
)
|
251
|
+
if (layer_id not in mlp_only_layers) and (
|
252
|
+
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
253
|
+
):
|
254
|
+
self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
|
255
|
+
else:
|
256
|
+
self.mlp = Qwen2MoeMLP(
|
257
|
+
hidden_size=config.hidden_size,
|
258
|
+
intermediate_size=config.intermediate_size,
|
259
|
+
hidden_act=config.hidden_act,
|
260
|
+
quant_config=quant_config,
|
261
|
+
)
|
262
|
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
263
|
+
self.post_attention_layernorm = RMSNorm(
|
264
|
+
config.hidden_size, eps=config.rms_norm_eps
|
265
|
+
)
|
266
|
+
|
267
|
+
def forward(
|
268
|
+
self,
|
269
|
+
positions: torch.Tensor,
|
270
|
+
hidden_states: torch.Tensor,
|
271
|
+
input_metadata: InputMetadata,
|
272
|
+
residual: Optional[torch.Tensor],
|
273
|
+
) -> torch.Tensor:
|
274
|
+
# Self Attention
|
275
|
+
if residual is None:
|
276
|
+
residual = hidden_states
|
277
|
+
hidden_states = self.input_layernorm(hidden_states)
|
278
|
+
else:
|
279
|
+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
280
|
+
hidden_states = self.self_attn(
|
281
|
+
positions=positions,
|
282
|
+
hidden_states=hidden_states,
|
283
|
+
input_metadata=input_metadata,
|
284
|
+
)
|
285
|
+
|
286
|
+
# Fully Connected
|
287
|
+
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
288
|
+
hidden_states = self.mlp(hidden_states)
|
289
|
+
return hidden_states, residual
|
290
|
+
|
291
|
+
|
292
|
+
class Qwen2MoeModel(nn.Module):
|
293
|
+
def __init__(
|
294
|
+
self,
|
295
|
+
config: PretrainedConfig,
|
296
|
+
cache_config: Optional[CacheConfig] = None,
|
297
|
+
quant_config: Optional[QuantizationConfig] = None,
|
298
|
+
) -> None:
|
299
|
+
super().__init__()
|
300
|
+
self.padding_idx = config.pad_token_id
|
301
|
+
self.vocab_size = config.vocab_size
|
302
|
+
|
303
|
+
self.embed_tokens = VocabParallelEmbedding(
|
304
|
+
config.vocab_size,
|
305
|
+
config.hidden_size,
|
306
|
+
)
|
307
|
+
self.layers = nn.ModuleList(
|
308
|
+
[
|
309
|
+
Qwen2MoeDecoderLayer(
|
310
|
+
config, layer_id, cache_config, quant_config=quant_config
|
311
|
+
)
|
312
|
+
for layer_id in range(config.num_hidden_layers)
|
313
|
+
]
|
314
|
+
)
|
315
|
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
316
|
+
|
317
|
+
def forward(
|
318
|
+
self,
|
319
|
+
input_ids: torch.Tensor,
|
320
|
+
positions: torch.Tensor,
|
321
|
+
input_metadata: InputMetadata,
|
322
|
+
input_embeds: torch.Tensor = None,
|
323
|
+
) -> torch.Tensor:
|
324
|
+
if input_embeds is None:
|
325
|
+
hidden_states = self.embed_tokens(input_ids)
|
326
|
+
else:
|
327
|
+
hidden_states = input_embeds
|
328
|
+
residual = None
|
329
|
+
for i in range(len(self.layers)):
|
330
|
+
layer = self.layers[i]
|
331
|
+
hidden_states, residual = layer(
|
332
|
+
positions, hidden_states, input_metadata, residual
|
333
|
+
)
|
334
|
+
hidden_states, _ = self.norm(hidden_states, residual)
|
335
|
+
return hidden_states
|
336
|
+
|
337
|
+
|
338
|
+
class Qwen2MoeForCausalLM(nn.Module):
|
339
|
+
|
340
|
+
fall_back_to_pt_during_load = False
|
341
|
+
|
342
|
+
def __init__(
|
343
|
+
self,
|
344
|
+
config: PretrainedConfig,
|
345
|
+
cache_config: Optional[CacheConfig] = None,
|
346
|
+
quant_config: Optional[QuantizationConfig] = None,
|
347
|
+
) -> None:
|
348
|
+
super().__init__()
|
349
|
+
self.config = config
|
350
|
+
self.quant_config = quant_config
|
351
|
+
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
352
|
+
self.lm_head = ParallelLMHead(
|
353
|
+
config.vocab_size, config.hidden_size, quant_config=quant_config
|
354
|
+
)
|
355
|
+
self.logits_processor = LogitsProcessor(config)
|
356
|
+
self.sampler = Sampler()
|
357
|
+
|
358
|
+
def forward(
|
359
|
+
self,
|
360
|
+
input_ids: torch.Tensor,
|
361
|
+
positions: torch.Tensor,
|
362
|
+
input_metadata: InputMetadata,
|
363
|
+
input_embeds: torch.Tensor = None,
|
364
|
+
) -> torch.Tensor:
|
365
|
+
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
366
|
+
return self.logits_processor(
|
367
|
+
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
368
|
+
)
|
369
|
+
|
370
|
+
def compute_logits(
|
371
|
+
self,
|
372
|
+
input_ids: torch.Tensor,
|
373
|
+
hidden_states: torch.Tensor,
|
374
|
+
input_metadata: InputMetadata,
|
375
|
+
) -> torch.Tensor:
|
376
|
+
logits = self.logits_processor(
|
377
|
+
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
378
|
+
)
|
379
|
+
return logits
|
380
|
+
|
381
|
+
def sample(
|
382
|
+
self,
|
383
|
+
logits: Optional[torch.Tensor],
|
384
|
+
sampling_metadata: SamplingMetadata,
|
385
|
+
) -> Optional[SamplerOutput]:
|
386
|
+
next_tokens = self.sampler(logits, sampling_metadata)
|
387
|
+
return next_tokens
|
388
|
+
|
389
|
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
390
|
+
stacked_params_mapping = [
|
391
|
+
# (param_name, shard_name, shard_id)
|
392
|
+
("qkv_proj", "q_proj", "q"),
|
393
|
+
("qkv_proj", "k_proj", "k"),
|
394
|
+
("qkv_proj", "v_proj", "v"),
|
395
|
+
("gate_up_proj", "gate_proj", 0),
|
396
|
+
("gate_up_proj", "up_proj", 1),
|
397
|
+
]
|
398
|
+
|
399
|
+
expert_params_mapping = [
|
400
|
+
# These are the weights for the experts
|
401
|
+
# (param_name, weight_name, expert_id, shard_id)
|
402
|
+
(
|
403
|
+
"experts.w13_weight"
|
404
|
+
if weight_name in ["gate_proj", "up_proj"]
|
405
|
+
else "experts.w2_weight",
|
406
|
+
f"experts.{expert_id}.{weight_name}.weight",
|
407
|
+
expert_id,
|
408
|
+
shard_id,
|
409
|
+
)
|
410
|
+
for expert_id in range(self.config.num_experts)
|
411
|
+
for shard_id, weight_name in enumerate(
|
412
|
+
["gate_proj", "down_proj", "up_proj"]
|
413
|
+
)
|
414
|
+
]
|
415
|
+
|
416
|
+
params_dict = dict(self.named_parameters())
|
417
|
+
for name, loaded_weight in weights:
|
418
|
+
if "rotary_emb.inv_freq" in name:
|
419
|
+
continue
|
420
|
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
421
|
+
# Skip non-stacked layers and experts (experts handled below).
|
422
|
+
if weight_name not in name:
|
423
|
+
continue
|
424
|
+
# We have mlp.experts[0].gate_proj in the checkpoint.
|
425
|
+
# Since we handle the experts below in expert_params_mapping,
|
426
|
+
# we need to skip here BEFORE we update the name, otherwise
|
427
|
+
# name will be updated to mlp.experts[0].gate_up_proj, which
|
428
|
+
# will then be updated below in expert_params_mapping
|
429
|
+
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
430
|
+
if "mlp.experts" in name:
|
431
|
+
continue
|
432
|
+
name = name.replace(weight_name, param_name)
|
433
|
+
# Skip loading extra bias for GPTQ models.
|
434
|
+
if name.endswith(".bias") and name not in params_dict:
|
435
|
+
continue
|
436
|
+
if name not in params_dict:
|
437
|
+
continue
|
438
|
+
|
439
|
+
param = params_dict[name]
|
440
|
+
weight_loader = param.weight_loader
|
441
|
+
weight_loader(param, loaded_weight, shard_id)
|
442
|
+
break
|
443
|
+
else:
|
444
|
+
for mapping in expert_params_mapping:
|
445
|
+
param_name, weight_name, expert_id, shard_id = mapping
|
446
|
+
if weight_name not in name:
|
447
|
+
continue
|
448
|
+
name = name.replace(weight_name, param_name)
|
449
|
+
param = params_dict[name]
|
450
|
+
weight_loader = param.weight_loader
|
451
|
+
weight_loader(
|
452
|
+
param,
|
453
|
+
loaded_weight,
|
454
|
+
weight_name,
|
455
|
+
shard_id=shard_id,
|
456
|
+
expert_id=expert_id,
|
457
|
+
)
|
458
|
+
break
|
459
|
+
else:
|
460
|
+
# Skip loading extra bias for GPTQ models.
|
461
|
+
if name.endswith(".bias") and name not in params_dict:
|
462
|
+
continue
|
463
|
+
if name not in params_dict:
|
464
|
+
continue
|
465
|
+
|
466
|
+
param = params_dict[name]
|
467
|
+
weight_loader = getattr(
|
468
|
+
param, "weight_loader", default_weight_loader
|
469
|
+
)
|
470
|
+
weight_loader(param, loaded_weight)
|
471
|
+
|
472
|
+
|
473
|
+
EntryClass = Qwen2MoeForCausalLM
|