sglang 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +5 -0
- sglang/global_config.py +4 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +52 -19
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +8 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/router/infer_batch.py +31 -19
- sglang/srt/managers/router/manager.py +6 -8
- sglang/srt/managers/router/model_rpc.py +59 -23
- sglang/srt/managers/router/model_runner.py +6 -6
- sglang/srt/managers/router/radix_cache.py +47 -17
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +54 -22
- sglang/srt/model_config.py +4 -0
- sglang/srt/models/commandr.py +6 -10
- sglang/srt/models/dbrx.py +14 -15
- sglang/srt/models/gemma.py +7 -10
- sglang/srt/models/llama2.py +7 -10
- sglang/srt/models/llava.py +2 -6
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +7 -13
- sglang/srt/models/qwen.py +20 -13
- sglang/srt/models/qwen2.py +7 -10
- sglang/srt/models/stablelm.py +13 -12
- sglang/srt/models/yivl.py +1 -4
- sglang/srt/server.py +32 -18
- sglang/srt/server_args.py +9 -6
- sglang/srt/utils.py +126 -17
- sglang/srt/weight_utils.py +66 -51
- sglang/utils.py +77 -26
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/models/dbrx.py
CHANGED
@@ -5,37 +5,31 @@ from typing import Optional
|
|
5
5
|
|
6
6
|
import torch
|
7
7
|
import torch.nn as nn
|
8
|
+
from vllm.distributed import (
|
9
|
+
get_tensor_model_parallel_rank,
|
10
|
+
get_tensor_model_parallel_world_size,
|
11
|
+
tensor_model_parallel_all_reduce,
|
12
|
+
)
|
8
13
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
9
14
|
from vllm.model_executor.layers.linear import (
|
10
15
|
QKVParallelLinear,
|
11
16
|
ReplicatedLinear,
|
12
17
|
RowParallelLinear,
|
13
18
|
)
|
14
|
-
from vllm.model_executor.layers.quantization.base_config import
|
15
|
-
QuantizationConfig)
|
19
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
16
20
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
17
21
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
18
22
|
DEFAULT_VOCAB_PADDING_SIZE,
|
19
23
|
ParallelLMHead,
|
20
24
|
VocabParallelEmbedding,
|
21
25
|
)
|
22
|
-
from vllm.distributed import (
|
23
|
-
tensor_model_parallel_all_reduce,
|
24
|
-
)
|
25
|
-
from vllm.distributed import (
|
26
|
-
get_tensor_model_parallel_rank,
|
27
|
-
get_tensor_model_parallel_world_size,
|
28
|
-
)
|
29
26
|
from vllm.model_executor.utils import set_weight_attrs
|
30
|
-
from sglang.srt.weight_utils import (
|
31
|
-
default_weight_loader,
|
32
|
-
hf_model_weights_iterator,
|
33
|
-
)
|
34
27
|
|
35
28
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
36
29
|
from sglang.srt.layers.radix_attention import RadixAttention
|
37
30
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
38
31
|
from sglang.srt.models.dbrx_config import DbrxConfig
|
32
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
39
33
|
|
40
34
|
|
41
35
|
class DbrxRouter(nn.Module):
|
@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
|
|
291
285
|
quant_config: Optional[QuantizationConfig] = None,
|
292
286
|
):
|
293
287
|
super().__init__()
|
294
|
-
self.norm_attn_norm = DbrxFusedNormAttention(
|
288
|
+
self.norm_attn_norm = DbrxFusedNormAttention(
|
289
|
+
config, layer_id, quant_config=quant_config
|
290
|
+
)
|
295
291
|
self.ffn = DbrxExperts(config, quant_config=quant_config)
|
296
292
|
|
297
293
|
def forward(
|
@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
|
|
322
318
|
config.d_model,
|
323
319
|
)
|
324
320
|
self.blocks = nn.ModuleList(
|
325
|
-
[
|
321
|
+
[
|
322
|
+
DbrxBlock(config, i, quant_config=quant_config)
|
323
|
+
for i in range(config.n_layers)
|
324
|
+
]
|
326
325
|
)
|
327
326
|
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
328
327
|
for module in self.modules():
|
sglang/srt/models/gemma.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7
7
|
from torch import nn
|
8
8
|
from transformers import PretrainedConfig
|
9
9
|
from vllm.config import LoRAConfig
|
10
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
10
11
|
from vllm.model_executor.layers.activation import GeluAndMul
|
11
12
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
12
13
|
from vllm.model_executor.layers.linear import (
|
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
|
|
14
15
|
QKVParallelLinear,
|
15
16
|
RowParallelLinear,
|
16
17
|
)
|
17
|
-
from vllm.model_executor.layers.quantization.base_config import
|
18
|
-
QuantizationConfig)
|
18
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
19
19
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
20
20
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
21
|
-
from vllm.distributed import (
|
22
|
-
get_tensor_model_parallel_world_size,
|
23
|
-
)
|
24
|
-
from sglang.srt.weight_utils import (
|
25
|
-
default_weight_loader,
|
26
|
-
hf_model_weights_iterator,
|
27
|
-
)
|
28
21
|
|
29
22
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
30
23
|
from sglang.srt.layers.radix_attention import RadixAttention
|
31
24
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
25
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
32
26
|
|
33
27
|
|
34
28
|
class GemmaMLP(nn.Module):
|
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
|
|
46
40
|
quant_config=quant_config,
|
47
41
|
)
|
48
42
|
self.down_proj = RowParallelLinear(
|
49
|
-
intermediate_size,
|
43
|
+
intermediate_size,
|
44
|
+
hidden_size,
|
45
|
+
bias=False,
|
46
|
+
quant_config=quant_config,
|
50
47
|
)
|
51
48
|
self.act_fn = GeluAndMul()
|
52
49
|
|
sglang/srt/models/llama2.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
|
|
6
6
|
import torch
|
7
7
|
from torch import nn
|
8
8
|
from transformers import LlamaConfig
|
9
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
9
10
|
from vllm.model_executor.layers.activation import SiluAndMul
|
10
11
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
11
12
|
from vllm.model_executor.layers.linear import (
|
@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
|
|
13
14
|
QKVParallelLinear,
|
14
15
|
RowParallelLinear,
|
15
16
|
)
|
16
|
-
from vllm.model_executor.layers.quantization.base_config import
|
17
|
-
QuantizationConfig)
|
17
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
18
18
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
19
19
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
20
20
|
ParallelLMHead,
|
21
21
|
VocabParallelEmbedding,
|
22
22
|
)
|
23
|
-
from vllm.distributed import (
|
24
|
-
get_tensor_model_parallel_world_size,
|
25
|
-
)
|
26
|
-
from sglang.srt.weight_utils import (
|
27
|
-
default_weight_loader,
|
28
|
-
hf_model_weights_iterator,
|
29
|
-
)
|
30
23
|
|
31
24
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
32
25
|
from sglang.srt.layers.radix_attention import RadixAttention
|
33
26
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
27
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
34
28
|
|
35
29
|
|
36
30
|
class LlamaMLP(nn.Module):
|
@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
|
|
49
43
|
quant_config=quant_config,
|
50
44
|
)
|
51
45
|
self.down_proj = RowParallelLinear(
|
52
|
-
intermediate_size,
|
46
|
+
intermediate_size,
|
47
|
+
hidden_size,
|
48
|
+
bias=False,
|
49
|
+
quant_config=quant_config,
|
53
50
|
)
|
54
51
|
if hidden_act != "silu":
|
55
52
|
raise ValueError(
|
sglang/srt/models/llava.py
CHANGED
@@ -7,12 +7,7 @@ import torch
|
|
7
7
|
from torch import nn
|
8
8
|
from transformers import CLIPVisionModel, LlavaConfig
|
9
9
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
10
|
-
from vllm.model_executor.layers.quantization.base_config import
|
11
|
-
QuantizationConfig)
|
12
|
-
from sglang.srt.weight_utils import (
|
13
|
-
default_weight_loader,
|
14
|
-
hf_model_weights_iterator,
|
15
|
-
)
|
10
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
16
11
|
|
17
12
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
18
13
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
|
|
22
17
|
unpad_image_shape,
|
23
18
|
)
|
24
19
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
20
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
25
21
|
|
26
22
|
|
27
23
|
class LlavaLlamaForCausalLM(nn.Module):
|
@@ -0,0 +1,307 @@
|
|
1
|
+
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from torch import nn
|
9
|
+
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
|
10
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
11
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
12
|
+
|
13
|
+
from sglang.srt.managers.router.infer_batch import ForwardMode
|
14
|
+
from sglang.srt.managers.router.model_runner import InputMetadata
|
15
|
+
from sglang.srt.mm_utils import (
|
16
|
+
get_anyres_image_grid_shape,
|
17
|
+
unpad_image,
|
18
|
+
unpad_image_shape,
|
19
|
+
)
|
20
|
+
from sglang.srt.models.llama2 import LlamaForCausalLM
|
21
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
22
|
+
|
23
|
+
|
24
|
+
class LlavaVidForCausalLM(nn.Module):
    """LLaVA-style video model: CLIP vision tower + pooled projector + Llama LM.

    Frames are encoded by a CLIP vision tower (created in ``load_weights``),
    average-pooled spatially by ``resampler``, projected by
    ``multi_modal_projector`` and written over the placeholder token
    embeddings before the Llama language model runs.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # The CLIP vision tower is created lazily in load_weights().
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        # Square average pooling applied to each frame's patch grid.
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the first image token with ``image_feature_len`` pad ids.

        Only a fixed per-request feature length is supported (the
        spatial_unpad / anyres layouts of the image model are not).

        Args:
            input_ids: list of token ids containing ``config.image_token_index``.
            pad_value: non-empty list of ids, repeated to build the padding.
            pt_shape: unused; kept for interface compatibility.
            image_size: unused; kept for interface compatibility.

        Returns:
            ``(new_input_ids, offset)``, where ``offset`` is the index of the
            replaced image token, i.e. where the padding starts.

        Raises:
            ValueError: if ``input_ids`` contains no image token.
        """
        new_image_feature_len = self.image_feature_len

        # Repeat pad_value enough times to cover new_image_feature_len entries.
        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # The image token itself is removed, so the sequence grows by
        # new_image_feature_len - 1 entries.
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode frames with the vision tower, pool them spatially and project.

        Args:
            pixel_values: frames accepted by the CLIP vision tower;
                presumably (num_frames, C, H, W) -- TODO confirm.

        Returns:
            Output of ``multi_modal_projector`` with shape
            (num_frames, pooled_tokens, proj_dim).

        Raises:
            ValueError: if ``vision_feature_select_strategy`` is not
                "default", "patch" or "full".
        """
        # NOTE: output_hidden_states=True keeps every layer's hidden states,
        # which is not memory efficient; only one layer is used below.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        strategy = self.vision_feature_select_strategy
        if strategy in ["default", "patch"]:
            # Drop the leading CLS token, keeping only patch tokens.
            selected_image_feature = selected_image_feature[:, 1:]
        elif strategy != "full":
            raise ValueError(f"Unexpected select feature strategy: {strategy}")
        # NOTE(review): with "full" the CLS token is kept, so the token count is
        # not height * width and the view below may fail -- confirm before use.

        # Rebuild the 2-D patch grid as (N, C, H, W) so AvgPool2d can pool it.
        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        # Pool, then flatten back into a token sequence: (N, pooled_tokens, C).
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        return self.multi_modal_projector(selected_image_feature)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.ndarray]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run the language model, splicing in vision features on EXTEND.

        Args:
            input_ids: token ids of the batch.
            positions: position of each token.
            input_metadata: batch metadata (forward mode, batch size,
                per-request extend start locations).
            pixel_values: per-request pixel array or None. A 4-D array is
                treated as a stack of video frames.
            image_sizes: unused by this model; kept for interface compatibility.
            image_offsets: per-request offset of the image placeholder within
                that request's extend segment.

        Returns:
            Whatever ``self.language_model`` returns. Forward modes other than
            EXTEND and DECODE return None.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # A request needs vision only if its extend segment starts inside
            # the image placeholder region and it actually carries pixels.
            # FIXME: We need to substract the length of the system prompt
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########
                if pixel_values[0].ndim == 4:
                    # Video: encode all frames of all requests in one batch,
                    # split per request and merge the frame axis into the
                    # token axis -> (num_frames * tokens, dim) per request.
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = [
                        feature.flatten(0, 1)
                        for feature in torch.split(image_features, split_sizes, dim=0)
                    ]
                else:
                    # One image per request: encode_images returns
                    # (BS, tokens, dim), so each entry is already (tokens, dim).
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                dim = input_embeds.shape[1]
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    begin = start_idx + image_offsets[i]
                    try:
                        input_embeds[begin : begin + pad_len] = image_features[pt]
                    except RuntimeError as e:
                        # Best effort: keep the placeholder embeddings and report.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Create the vision tower and load projector and language-model weights.

        Also derives ``image_feature_len`` (the number of placeholder tokens
        per request) and installs the patched CLIP embedding forward.

        Raises:
            ValueError: if ``mm_vision_select_feature`` is neither "patch"
                nor "cls_patch".
        """
        # load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        # Placeholder tokens per request: pooled patches per frame * frames.
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        # NOTE(review): encode_images() accepts "default"/"patch"/"full", so a
        # "cls_patch" config passes here but raises there -- confirm intent.

        # Map checkpoint parameter prefixes to this module's parameter names.
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            # Update the vision tower weights if the checkpoint has them
            # (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" not in name and "vision_tower" not in name:
                continue
            for weight_name, param_name in projector_weights.items():
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
            param = params_dict.get(name)
            if param is None:
                print(f"Warning: {name} not found in the model")
                continue
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Number of vision patches along one side of the (square) input image."""
        return self.image_size // self.patch_size
|
273
|
+
|
274
|
+
|
275
|
+
# Module-level flag read by clip_vision_embed_forward; that function sets it
# to False after moving a patch_embedding conv to CPU.
first_call = True
|
276
|
+
|
277
|
+
|
278
|
+
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``CLIPVisionEmbeddings.forward`` that runs the patch conv on CPU.

    The patch-embedding convolution is executed on CPU in float32, to avoid a
    bug in torch >= 2.1 on A10G. Its output is moved back to the device and
    dtype of the instance's class embedding before the class and position
    embeddings are added.

    Args:
        self: the ``CLIPVisionEmbeddings`` instance being patched.
        pixel_values: image batch; presumably (batch, C, H, W) -- TODO confirm.

    Returns:
        Embeddings of shape (batch, 1 + num_patches, embed_dim).
    """
    global first_call
    batch_size = pixel_values.shape[0]
    # The remaining embedding parameters stay where the tower was placed; the
    # conv output is moved back to match them.
    target_device = self.class_embedding.device
    target_dtype = self.class_embedding.dtype

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    # Also check the layer itself, not just the module-level flag. A vision
    # tower created after the first call (e.g. by another load_weights())
    # must be moved too, or its GPU conv would receive CPU input.
    weight = self.patch_embedding.weight
    if first_call or weight.device.type != "cpu" or weight.dtype != torch.float32:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).to(
        device=target_device, dtype=target_dtype
    )

    # (batch, embed_dim, h, w) -> (batch, h * w, embed_dim)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
|
295
|
+
|
296
|
+
|
297
|
+
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``.

    The patch is applied to the class, so it affects every CLIP vision
    embedding module in the process, including ones that already exist.
    """
    from transformers.models.clip import modeling_clip

    modeling_clip.CLIPVisionEmbeddings.forward = clip_vision_embed_forward
|
305
|
+
|
306
|
+
|
307
|
+
# Model class exposed under the conventional `EntryClass` name (other model
# modules, e.g. qwen.py, define it too); presumably the model loader looks it
# up by this name -- confirm.
EntryClass = LlavaVidForCausalLM
|
sglang/srt/models/mixtral.py
CHANGED
@@ -8,34 +8,28 @@ import torch
|
|
8
8
|
import torch.nn.functional as F
|
9
9
|
from torch import nn
|
10
10
|
from transformers import MixtralConfig
|
11
|
+
from vllm.distributed import (
|
12
|
+
get_tensor_model_parallel_rank,
|
13
|
+
get_tensor_model_parallel_world_size,
|
14
|
+
tensor_model_parallel_all_reduce,
|
15
|
+
)
|
11
16
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
12
17
|
from vllm.model_executor.layers.linear import (
|
13
18
|
QKVParallelLinear,
|
14
19
|
ReplicatedLinear,
|
15
20
|
RowParallelLinear,
|
16
21
|
)
|
17
|
-
from vllm.model_executor.layers.quantization.base_config import
|
18
|
-
QuantizationConfig)
|
22
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
19
23
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
20
24
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
21
25
|
ParallelLMHead,
|
22
26
|
VocabParallelEmbedding,
|
23
27
|
)
|
24
|
-
from vllm.distributed import (
|
25
|
-
tensor_model_parallel_all_reduce,
|
26
|
-
)
|
27
|
-
from vllm.distributed import (
|
28
|
-
get_tensor_model_parallel_rank,
|
29
|
-
get_tensor_model_parallel_world_size,
|
30
|
-
)
|
31
|
-
from sglang.srt.weight_utils import (
|
32
|
-
default_weight_loader,
|
33
|
-
hf_model_weights_iterator,
|
34
|
-
)
|
35
28
|
|
36
29
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
37
30
|
from sglang.srt.layers.radix_attention import RadixAttention
|
38
31
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
32
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
39
33
|
|
40
34
|
|
41
35
|
class MixtralMLP(nn.Module):
|
sglang/srt/models/qwen.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
|
3
3
|
import torch
|
4
4
|
from torch import nn
|
5
5
|
from transformers import PretrainedConfig
|
6
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
6
7
|
from vllm.model_executor.layers.activation import SiluAndMul
|
7
8
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
8
9
|
from vllm.model_executor.layers.linear import (
|
@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
|
|
10
11
|
QKVParallelLinear,
|
11
12
|
RowParallelLinear,
|
12
13
|
)
|
13
|
-
from vllm.model_executor.layers.quantization.base_config import
|
14
|
-
QuantizationConfig)
|
14
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
15
15
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
16
16
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
17
17
|
ParallelLMHead,
|
18
18
|
VocabParallelEmbedding,
|
19
19
|
)
|
20
|
-
from vllm.distributed import (
|
21
|
-
get_tensor_model_parallel_world_size,
|
22
|
-
)
|
23
|
-
from sglang.srt.weight_utils import (
|
24
|
-
default_weight_loader,
|
25
|
-
hf_model_weights_iterator,
|
26
|
-
)
|
27
20
|
|
28
21
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
29
22
|
from sglang.srt.layers.radix_attention import RadixAttention
|
30
23
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
24
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
31
25
|
|
32
26
|
|
33
27
|
class QWenMLP(nn.Module):
|
@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
|
|
132
126
|
|
133
127
|
|
134
128
|
class QWenBlock(nn.Module):
|
135
|
-
def __init__(
|
129
|
+
def __init__(
|
130
|
+
self,
|
131
|
+
config: PretrainedConfig,
|
132
|
+
layer_id,
|
133
|
+
quant_config: Optional[QuantizationConfig] = None,
|
134
|
+
):
|
136
135
|
super().__init__()
|
137
136
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
138
137
|
|
@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
|
|
181
180
|
|
182
181
|
|
183
182
|
class QWenModel(nn.Module):
|
184
|
-
def __init__(
|
183
|
+
def __init__(
|
184
|
+
self,
|
185
|
+
config: PretrainedConfig,
|
186
|
+
quant_config: Optional[QuantizationConfig] = None,
|
187
|
+
):
|
185
188
|
super().__init__()
|
186
189
|
self.config = config
|
187
190
|
self.vocab_size = config.vocab_size
|
@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
|
|
218
221
|
|
219
222
|
|
220
223
|
class QWenLMHeadModel(nn.Module):
|
221
|
-
def __init__(
|
224
|
+
def __init__(
|
225
|
+
self,
|
226
|
+
config: PretrainedConfig,
|
227
|
+
quant_config: Optional[QuantizationConfig] = None,
|
228
|
+
):
|
222
229
|
super().__init__()
|
223
230
|
self.config = config
|
224
231
|
self.transformer = QWenModel(config, quant_config=quant_config)
|
@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
|
|
276
283
|
weight_loader(param, loaded_weight)
|
277
284
|
|
278
285
|
|
279
|
-
EntryClass = QWenLMHeadModel
|
286
|
+
EntryClass = QWenLMHeadModel
|
sglang/srt/models/qwen2.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
|
|
5
5
|
|
6
6
|
import torch
|
7
7
|
from torch import nn
|
8
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
8
9
|
from vllm.model_executor.layers.activation import SiluAndMul
|
9
10
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
10
11
|
from vllm.model_executor.layers.linear import (
|
@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
|
|
12
13
|
QKVParallelLinear,
|
13
14
|
RowParallelLinear,
|
14
15
|
)
|
15
|
-
from vllm.model_executor.layers.quantization.base_config import
|
16
|
-
QuantizationConfig)
|
16
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
17
17
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
18
18
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
19
19
|
ParallelLMHead,
|
20
20
|
VocabParallelEmbedding,
|
21
21
|
)
|
22
|
-
from vllm.distributed import (
|
23
|
-
get_tensor_model_parallel_world_size,
|
24
|
-
)
|
25
|
-
from sglang.srt.weight_utils import (
|
26
|
-
default_weight_loader,
|
27
|
-
hf_model_weights_iterator,
|
28
|
-
)
|
29
22
|
|
30
23
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
31
24
|
from sglang.srt.layers.radix_attention import RadixAttention
|
32
25
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
26
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
33
27
|
|
34
28
|
Qwen2Config = None
|
35
29
|
|
@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
|
|
50
44
|
quant_config=quant_config,
|
51
45
|
)
|
52
46
|
self.down_proj = RowParallelLinear(
|
53
|
-
intermediate_size,
|
47
|
+
intermediate_size,
|
48
|
+
hidden_size,
|
49
|
+
bias=False,
|
50
|
+
quant_config=quant_config,
|
54
51
|
)
|
55
52
|
if hidden_act != "silu":
|
56
53
|
raise ValueError(
|