sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +51 -13
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +70 -78
- sglang/srt/managers/schedule_batch.py +33 -49
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +137 -80
- sglang/srt/managers/tokenizer_manager.py +224 -336
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/model_runner.py +8 -17
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/server.py +31 -35
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/runners.py +2 -1
- sglang/test/test_utils.py +73 -25
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,7 +15,6 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
-import json
 import logging
 from typing import Optional
 
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj,
+from sglang.srt.utils import broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)
 
@@ -48,9 +47,10 @@ class TpModelWorker:
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
         if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
-            if
+            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
sglang/srt/mem_cache/flush_cache.py
CHANGED
@@ -29,5 +29,5 @@ if __name__ == "__main__":
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     args = parser.parse_args()
 
-    response = requests.
+    response = requests.post(args.url + "/flush_cache")
     assert response.status_code == 200
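For reference, the repaired line is a plain HTTP POST to the server's /flush_cache endpoint. A minimal standalone sketch, assuming an sglang server is listening on the script's default URL:

# Sketch only: flush the cache of a running sglang server.
# Assumes the default URL used by the script above.
import requests

response = requests.post("http://localhost:30000/flush_cache")
assert response.status_code == 200, "cache flush failed"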
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -113,18 +113,21 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2,
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
         self.compile_bs = (
             [
                 bs
                 for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
             ]
             if self.use_torch_compile
             else []
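For orientation, the new default expression above produces a fixed list of candidate CUDA-graph batch sizes before the req_to_token_pool.size and cuda_graph_max_bs filters are applied. A quick sketch of what it evaluates to:

# The default capture list from the hunk above: 1, 2, 4, then every
# multiple of 8 up to 160 (23 candidate sizes before filtering).
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
print(capture_bs)       # [1, 2, 4, 8, 16, 24, ..., 152, 160]
print(len(capture_bs))  # 23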
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    is_attention_free_model,
-    is_embedding_model,
-    is_generation_model,
-    is_multimodal_model,
-    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -93,9 +88,8 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
-        self.
-
-        )
+        self.is_generation = model_config.is_generation
+        self.is_multimodal = model_config.is_multimodal
 
         # Model-specific adjustment
         if (
@@ -119,12 +113,12 @@ class ModelRunner:
             self.server_args.ds_heavy_channel_type
         )
 
-        if self.
+        if self.is_multimodal:
             logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
-
+            self.mem_fraction_static *= 0.95
         # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
@@ -270,9 +264,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         logger.info(
             f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model",
-setattr(ModelRegistry, "is_attention_free_model",
-setattr(ModelRegistry, "model_has_inner_state",
-setattr(ModelRegistry, "is_embedding_model",
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
sglang/srt/models/baichuan.py
CHANGED
@@ -34,10 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/chatglm.py
CHANGED
@@ -24,10 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig
 
@@ -41,6 +37,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 LoraConfig = None
sglang/srt/models/commandr.py
CHANGED
@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -62,6 +61,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/dbrx.py
CHANGED
@@ -27,11 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -43,6 +38,11 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/deepseek.py
CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -27,10 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -44,6 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
sglang/srt/models/exaone.py
CHANGED
@@ -23,10 +23,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma.py
CHANGED
@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma2.py
CHANGED
@@ -24,7 +24,6 @@ from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gpt2.py
ADDED
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The vLLM team.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-2 model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPT2Config
+from vllm.config import CacheConfig
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+#from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class GPT2Attention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.attn = RadixAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling=self.scale,
+                                   num_kv_heads=total_num_heads,
+                                   layer_id=layer_id)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPT2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(layer_id,
+                                  config,
+                                  cache_config,
+                                  quant_config,
+                                  prefix=f"{prefix}.attn")
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim,
+                           config,
+                           quant_config,
+                           prefix=f"{prefix}.mlp")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [
+                GPT2Block(i, config, cache_config, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, forward_batch)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.transformer = GPT2Model(config,
+                                     cache_config,
+                                     quant_config,
+                                     prefix="transformer")
+        self.lm_head = self.transformer.wte
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+EntryClass = GPT2LMHeadModel
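The load_weights loop above transposes the c_attn/c_proj/c_fc weights because, as the file's own comment notes, HuggingFace GPT-2 stores them in Conv1D layout. A small illustration with hypothetical shapes (not part of the package):

# Illustration only, made-up shapes: HF GPT-2 Conv1D keeps weights as
# (in_features, out_features), while a Linear-style loader expects
# (out_features, in_features), hence loaded_weight.t() in load_weights.
import torch

hf_c_fc_weight = torch.randn(768, 3072)   # Conv1D layout: (in, out)
linear_style_weight = hf_c_fc_weight.t()  # (3072, 768) after the transpose
print(linear_style_weight.shape)          # torch.Size([3072, 768])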
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/grok.py
CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/internlm2.py
CHANGED
@@ -23,10 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 