sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +111 -167
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +20 -21
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -447
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
- sglang-0.1.15.dist-info/RECORD +69 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py
CHANGED
@@ -4,47 +4,49 @@
 from typing import Optional, Tuple

 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )

+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+

 class GemmaMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False,
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()

@@ -65,7 +67,7 @@ class GemmaAttention(nn.Module):
         layer_id: int = 0,
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -95,13 +97,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )

         self.rotary_emb = get_rope(
@@ -138,7 +140,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,12 +152,12 @@ class GemmaDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -191,7 +193,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -202,7 +204,7 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i,
+                GemmaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -263,14 +265,14 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
         self.config = config
-        self.
-        self.model = GemmaModel(config,
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
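Every model file in this release follows the pattern shown above for gemma.py: the LinearMethodBase import is dropped, each constructor gains a quant_config: Optional[QuantizationConfig] = None parameter, and that value is threaded down into the parallel linear layers. A minimal usage sketch, assuming the sglang 0.1.15 and vllm packages are importable (the build_gemma helper below is illustrative, not part of the diff):

from typing import Optional

from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.models.gemma import GemmaForCausalLM


def build_gemma(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
) -> GemmaForCausalLM:
    # Illustrative helper, not part of sglang. Passing quant_config=None keeps
    # the unquantized path; a concrete QuantizationConfig is forwarded unchanged
    # through GemmaModel -> GemmaDecoderLayer -> GemmaAttention / GemmaMLP and
    # into the vllm parallel linear layers constructed in the diff above.
    return GemmaForCausalLM(config, quant_config=quant_config)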
sglang/srt/models/llama2.py
CHANGED
@@ -1,35 +1,37 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict,
+from typing import Any, Dict, Optional, Tuple

 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )

+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+

 class LlamaMLP(nn.Module):
     def __init__(
@@ -37,17 +39,17 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False,
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -73,7 +75,7 @@ class LlamaAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,13 +106,13 @@ class LlamaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )

         self.rotary_emb = get_rope(
@@ -147,7 +149,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         config: LlamaConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -162,13 +164,13 @@ class LlamaDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -204,7 +206,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -216,7 +218,7 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i,
+                LlamaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -250,12 +252,12 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-        self.model = LlamaModel(config,
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/llava.py
CHANGED
@@ -4,6 +4,16 @@ from typing import List, Optional

 import numpy as np
 import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from sglang.srt.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.mm_utils import (
@@ -12,21 +22,13 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from torch import nn
-from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
-from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.linear import LinearMethodBase
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)


 class LlavaLlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlavaConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -34,7 +36,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config,
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
sglang/srt/models/mixtral.py
CHANGED
@@ -1,40 +1,42 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import
+from typing import Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import MixtralConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
+from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )

+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+

 class MixtralMLP(nn.Module):
     def __init__(
@@ -42,7 +44,7 @@ class MixtralMLP(nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
@@ -50,13 +52,13 @@ class MixtralMLP(nn.Module):
         self.hidden_dim = hidden_size

         self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False,
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
         )
         self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False,
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
         )
         self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False,
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
         )

         # TODO: Use vllm's SiluAndMul
@@ -75,7 +77,7 @@ class MixtralMoE(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -102,7 +104,7 @@ class MixtralMoE(nn.Module):
                     self.num_total_experts,
                     config.hidden_size,
                     config.intermediate_size,
-
+                    quant_config=quant_config,
                 )
                 if idx in self.expert_indicies
                 else None
@@ -147,7 +149,7 @@ class MixtralAttention(nn.Module):
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-
+        quant_config: Optional[QuantizationConfig] = None,
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -179,13 +181,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -221,7 +223,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         config: MixtralConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -235,9 +237,9 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-
+            quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config,
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -272,7 +274,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -285,7 +287,7 @@ class MixtralModel(nn.Module):
         # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i,
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -316,12 +318,12 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-        self.model = MixtralModel(config,
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/qwen.py
CHANGED
@@ -1,32 +1,34 @@
-from typing import Any, Dict,
+from typing import Any, Dict, Optional

 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )

+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+

 class QWenMLP(nn.Module):
     def __init__(
@@ -34,7 +36,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,14 +44,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +76,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -90,14 +92,14 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -130,7 +132,7 @@ class QWenAttention(nn.Module):


 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id,
+    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -143,7 +145,7 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-
+            quant_config=quant_config,
         )

         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -151,7 +153,7 @@ class QWenBlock(nn.Module):
         self.mlp = QWenMLP(
             config.hidden_size,
             config.intermediate_size // 2,
-
+            quant_config=quant_config,
         )

     def forward(
@@ -179,7 +181,7 @@ class QWenBlock(nn.Module):


 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig,
+    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -191,7 +193,7 @@ class QWenModel(nn.Module):
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i,
+                QWenBlock(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -216,10 +218,10 @@ class QWenModel(nn.Module):


 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig,
+    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config,
+        self.transformer = QWenModel(config, quant_config=quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -274,4 +276,4 @@ class QWenLMHeadModel(nn.Module):
                 weight_loader(param, loaded_weight)


-EntryClass = QWenLMHeadModel
+EntryClass = QWenLMHeadModel