sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,12 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import
+from typing import Dict, List
 
 import numpy as np
 import transformers
@@ -22,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchStrOut,
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -90,7 +92,7 @@ class TokenizerManager:
         )
 
         self.to_create_loop = True
-        self.rid_to_state: Dict[str, ReqState] = {}
+        self.rid_to_state: Dict[str, ReqState] = {}
 
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -283,7 +285,7 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_router.send_pyobj(req)
 
-    def create_abort_task(self, obj):
+    def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
             await asyncio.sleep(3)
@@ -314,14 +316,13 @@ class TokenizerManager:
 
                 recv_obj.meta_info[i]["id"] = rid
                 out_dict = {
-                    "text": recv_obj.
+                    "text": recv_obj.output_strs[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
                 state.out_list.append(out_dict)
                 state.finished = recv_obj.finished_reason[i] is not None
                 state.event.set()
 
-
     def convert_logprob_style(
         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
     ):
@@ -332,17 +333,18 @@
             ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
-
-
-
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                 )
-
-
-                self.detokenize_top_logprobs_tokens(
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                 )
-            )
         return ret
 
     def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
@@ -382,7 +384,7 @@ def get_pixel_values(
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
-        if image_size
+        if image_size is not None:
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
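For context, the `create_abort_task` hunk above types the request object and keeps the disconnect-abort pattern: wait a short grace period, then send an `AbortReq` for the request id if the client has gone away. Below is a minimal, self-contained sketch of that pattern, not the actual sglang internals; the `is_disconnected` callback and the queue standing in for `send_to_router` are assumptions for illustration.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable


@dataclass
class AbortReq:
    rid: str  # request id to cancel


async def abort_if_disconnected(
    rid: str,
    is_disconnected: Callable[[], Awaitable[bool]],
    send_to_router: asyncio.Queue,
    grace_seconds: float = 3.0,
) -> None:
    # Give the request a short grace period before checking the connection,
    # mirroring the `await asyncio.sleep(3)` in the hunk above.
    await asyncio.sleep(grace_seconds)
    if await is_disconnected():
        # Tell the router/scheduler to drop the request.
        await send_to_router.put(AbortReq(rid))


async def demo() -> None:
    q: asyncio.Queue = asyncio.Queue()

    async def fake_disconnected() -> bool:
        return True  # pretend the client went away

    asyncio.create_task(abort_if_disconnected("req-42", fake_disconnected, q, 0.01))
    print(await q.get())  # AbortReq(rid='req-42')


asyncio.run(demo())
```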
sglang/srt/model_config.py
CHANGED
@@ -1,5 +1,7 @@
 from typing import Optional
 
+from transformers import PretrainedConfig
+
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -16,9 +18,13 @@ class ModelConfig:
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.model_overide_args = model_overide_args
-        self.hf_config = get_config(
-
-
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -43,4 +49,83 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1, total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
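The new `get_num_kv_heads` helper divides the total KV-head count by the tensor-parallel size, flooring at one so that every GPU keeps at least one (replicated) KV head. A small standalone sketch of that arithmetic, with made-up head counts for illustration:

```python
def kv_heads_per_gpu(total_num_kv_heads: int, tensor_parallel_size: int) -> int:
    # Same rule as ModelConfig.get_num_kv_heads above: divide across TP ranks,
    # but never go below one head per GPU (heads get replicated instead).
    return max(1, total_num_kv_heads // tensor_parallel_size)


# A GQA model with 8 KV heads:
print(kv_heads_per_gpu(8, 2))   # 4 KV heads per GPU
print(kv_heads_per_gpu(8, 8))   # 1 KV head per GPU
print(kv_heads_per_gpu(8, 16))  # 1 KV head per GPU -> heads are replicated
```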
sglang/srt/models/chatglm.py
ADDED
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+"""Inference-only ChatGLM model compatible with THUDM weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import ChatGLMConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+LoraConfig = None
+
+
+class GLMAttention(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.multi_query_attention = config.multi_query_attention
+        self.total_num_kv_heads = (
+            config.multi_query_group_num
+            if config.multi_query_attention
+            else config.num_attention_heads
+        )
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.add_bias_linear or config.add_qkv_bias,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+        rope_ratio = getattr(config, "rope_ratio", 1.0)
+        max_positions = getattr(config, "seq_length", 8192)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            input_metadata,
+        )
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class GLMMLP(nn.Module):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension.
+    """
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.add_bias = config.add_bias_linear
+
+        # Project to 4h.
+        self.dense_h_to_4h = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.ffn_hidden_size] * 2,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        self.activation_func = SiluAndMul()
+
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+    def forward(self, hidden_states):
+        # [s, b, 4hp]
+        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output, _ = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+class GLMBlock(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm
+        )
+
+        self.fp32_residual_connection = config.fp32_residual_connection
+
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+        # Layernorm on the input data.
+        self.input_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon
+        )
+
+        # Self attention.
+        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.hidden_dropout = config.hidden_dropout
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon
+        )
+
+        # MLP
+        self.mlp = GLMMLP(config, quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # hidden_states: [num_tokens, h]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.self_attention(
+            hidden_states=layernorm_output,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        layernorm_input = residual + attention_output
+
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = self.mlp(layernorm_output) + residual
+
+        return output
+
+
+class GLMTransformer(nn.Module):
+    """Transformer class."""
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.post_layer_norm = config.post_layer_norm
+
+        # Number of layers.
+        self.num_layers = config.num_layers
+
+        # Transformer layers.
+        self.layers = nn.ModuleList(
+            [
+                GLMBlock(config, i, cache_config, quant_config)
+                for i in range(self.num_layers)
+            ]
+        )
+
+        if self.post_layer_norm:
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                input_metadata=input_metadata,
+            )
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class ChatGLMModel(nn.Module):
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.embedding = VocabParallelEmbedding(
+            config.padded_vocab_size, config.hidden_size
+        )
+
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
+        self.encoder = GLMTransformer(config, cache_config, quant_config)
+
+        self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=inputs_embeds,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+
+class ChatGLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoraConfig] = None,
+    ):
+        super().__init__()
+        self.config: ChatGLMConfig = config
+        self.quant_config = quant_config
+        self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
+        self.transformer = ChatGLMModel(config, cache_config, quant_config)
+        self.lm_head = self.transformer.output_layer
+        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_pos_emb.inv_freq" in name:
+                continue
+            if "word_embeddings" in name:
+                name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = ChatGLMForCausalLM
+# compat: glm model.config class == ChatGLMModel
+EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
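In `GLMAttention` above, the fused QKV projection is split by `q_size` and `kv_size`, which are derived from the per-rank head counts. The sketch below works through those sizes for a hypothetical ChatGLM-like configuration; the numbers are illustrative only, not taken from a real checkpoint.

```python
hidden_size = 4096
total_num_heads = 32
multi_query_group_num = 2   # KV groups when multi_query_attention is enabled
tp_size = 1                 # tensor-parallel world size

head_dim = hidden_size // total_num_heads                 # 128
num_heads = total_num_heads // tp_size                    # 32
num_kv_heads = max(1, multi_query_group_num // tp_size)   # 2

q_size = num_heads * head_dim                             # 4096
kv_size = num_kv_heads * head_dim                         # 256

# The fused projection output is then split as in GLMAttention.forward:
#   q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q_size, kv_size)  # 4096 256
```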
sglang/srt/models/commandr.py
CHANGED
@@ -23,7 +23,7 @@
 
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/models/dbrx.py
CHANGED
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
sglang/srt/models/gemma.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -310,6 +310,10 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
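The gemma hunk skips `lm_head.weight` because, as its comment notes, the output head is tied to the token embedding, so the checkpoint entry has no standalone parameter to load into. Below is a minimal sketch of that guard inside a generic weight-loading loop; `params_dict` and the loader call follow the shape of the surrounding code rather than the exact sglang implementation.

```python
import torch
from torch import nn


def load_weights_skipping_tied_head(weights, params_dict):
    for name, loaded_weight in weights:
        # The output head is tied to the embedding table, so there is no
        # separate "lm_head.weight" parameter to load into -- skip it.
        if "lm_head.weight" in name:
            continue
        # Skip loading extra bias for GPTQ-style checkpoints.
        if name.endswith(".bias") and name not in params_dict:
            continue
        param = params_dict[name]
        param.data.copy_(loaded_weight)


# Tiny demo: a module whose head shares the embedding weight.
emb = nn.Embedding(10, 4)
params = {"embed_tokens.weight": emb.weight}
ckpt = [
    ("embed_tokens.weight", torch.randn(10, 4)),
    ("lm_head.weight", torch.randn(10, 4)),  # present in the file, but skipped
]
load_weights_skipping_tied_head(ckpt, params)
```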