sglang-0.1.15-py3-none-any.whl → sglang-0.1.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +5 -0
  3. sglang/global_config.py +4 -1
  4. sglang/lang/chat_template.py +9 -2
  5. sglang/lang/interpreter.py +52 -19
  6. sglang/lang/ir.py +12 -9
  7. sglang/lang/tracer.py +1 -1
  8. sglang/launch_server.py +1 -2
  9. sglang/launch_server_llavavid.py +31 -0
  10. sglang/srt/flush_cache.py +16 -0
  11. sglang/srt/hf_transformers_utils.py +8 -1
  12. sglang/srt/managers/io_struct.py +15 -3
  13. sglang/srt/managers/router/infer_batch.py +31 -19
  14. sglang/srt/managers/router/manager.py +6 -8
  15. sglang/srt/managers/router/model_rpc.py +59 -23
  16. sglang/srt/managers/router/model_runner.py +6 -6
  17. sglang/srt/managers/router/radix_cache.py +47 -17
  18. sglang/srt/managers/router/scheduler.py +17 -28
  19. sglang/srt/managers/tokenizer_manager.py +54 -22
  20. sglang/srt/model_config.py +4 -0
  21. sglang/srt/models/commandr.py +6 -10
  22. sglang/srt/models/dbrx.py +14 -15
  23. sglang/srt/models/gemma.py +7 -10
  24. sglang/srt/models/llama2.py +7 -10
  25. sglang/srt/models/llava.py +2 -6
  26. sglang/srt/models/llavavid.py +307 -0
  27. sglang/srt/models/mixtral.py +7 -13
  28. sglang/srt/models/qwen.py +20 -13
  29. sglang/srt/models/qwen2.py +7 -10
  30. sglang/srt/models/stablelm.py +13 -12
  31. sglang/srt/models/yivl.py +1 -4
  32. sglang/srt/server.py +32 -18
  33. sglang/srt/server_args.py +9 -6
  34. sglang/srt/utils.py +126 -17
  35. sglang/srt/weight_utils.py +66 -51
  36. sglang/utils.py +77 -26
  37. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
  38. sglang-0.1.16.dist-info/RECORD +72 -0
  39. sglang-0.1.15.dist-info/RECORD +0 -69
  40. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  41. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  42. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
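For readers who want to reproduce this kind of comparison locally, here is a minimal sketch (not part of either release). It assumes the two wheels, sglang-0.1.15-py3-none-any.whl and sglang-0.1.16-py3-none-any.whl, have already been fetched (for example with `pip download sglang==0.1.15 --no-deps`) into the working directory; it then opens each wheel as a zip archive and prints a unified diff of every file that changed. The per-file +/- counts in the table above come from exactly this kind of file-by-file comparison.

import difflib
import zipfile


def wheel_files(path):
    # Map each archive member name to its decoded text content.
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace")
            for name in zf.namelist()
            if not name.endswith("/")
        }


def diff_wheels(old_path, new_path):
    # Yield unified-diff lines for every file that differs between the wheels.
    old, new = wheel_files(old_path), wheel_files(new_path)
    for name in sorted(set(old) | set(new)):
        old_lines = old.get(name, "").splitlines(keepends=True)
        new_lines = new.get(name, "").splitlines(keepends=True)
        if old_lines != new_lines:
            yield from difflib.unified_diff(
                old_lines, new_lines, fromfile=f"a/{name}", tofile=f"b/{name}"
            )


if __name__ == "__main__":
    for line in diff_wheels(
        "sglang-0.1.15-py3-none-any.whl", "sglang-0.1.16-py3-none-any.whl"
    ):
        print(line, end="")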
sglang/srt/models/dbrx.py CHANGED
@@ -5,37 +5,31 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.models.dbrx_config import DbrxConfig
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class DbrxRouter(nn.Module):
@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
         self.ffn = DbrxExperts(config, quant_config=quant_config)
 
     def forward(
@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
        )
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
sglang/srt/models/gemma.py CHANGED
@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class GemmaMLP(nn.Module):
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
sglang/srt/models/llama2.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlamaMLP(nn.Module):
@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
        )
         if hidden_act != "silu":
             raise ValueError(
sglang/srt/models/llava.py CHANGED
@@ -7,12 +7,7 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaLlamaForCausalLM(nn.Module):
sglang/srt/models/llavavid.py ADDED
@@ -0,0 +1,307 @@
+"""Inference-only LLaVa video model compatible with HuggingFace weights."""
+
+import os
+from typing import List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
+from sglang.srt.managers.router.infer_batch import ForwardMode
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
+from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+
+
+class LlavaVidForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
+        self.resampler = nn.AvgPool2d(
+            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
+        )
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        self.num_frames = getattr(self.config, "num_frames", 16)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        # if self.mm_patch_merge_type.startswith("spatial"):
+        #     height = width = self.num_patches_per_side
+        #     if pt_shape[0] > 1:
+        #         if self.image_aspect_ratio == "anyres":
+        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        #                 image_size,
+        #                 self.image_grid_pinpoints,
+        #                 self.vision_tower.config.image_size,
+        #             )
+        #         if "unpad" in self.mm_patch_merge_type:
+        #             h = num_patch_height * height
+        #             w = num_patch_width * width
+        #             new_h, new_w = unpad_image_shape(h, w, image_size)
+        #             new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        # print(input_ids)
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
+
+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+            )
+
+        height = width = self.num_patches_per_side
+        num_of_frames = selected_image_feature.shape[0]
+        selected_image_feature = selected_image_feature.view(
+            num_of_frames, height, width, -1
+        )
+        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
+        selected_image_feature = (
+            self.resampler(selected_image_feature)
+            .flatten(2)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+
+        return image_features
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        pixel_values: Optional[List[Optional[np.array]]] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        image_offsets: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
+            bs = input_metadata.batch_size
+
+            # Embed text input
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+            # Embed vision input
+            need_vision = (
+                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
+                .cpu()
+                .numpy()
+            )
+            # FIXME: We need to substract the length of the system prompt
+            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
+            need_vision = need_vision & has_pixel
+
+            if need_vision.any():
+                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+
+                ########## Encode Image ########
+
+                if pixel_values[0].ndim == 4:
+                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                    np.concatenate(pixel_values, axis=0)
+                    # ndim=4
+                    concat_images = torch.tensor(
+                        np.concatenate(pixel_values, axis=0),
+                        device=self.vision_tower.device,
+                    )
+                    # image_features = self.encode_images(concat_images)
+                    # split_sizes = [image.shape[0] for image in pixel_values]
+                    # image_features = torch.split(image_features, split_sizes, dim=0)
+                    image_features = self.encode_images(
+                        concat_images
+                    )  # , prompts)#, image_counts, long_video=long_video)
+                    split_sizes = [image.shape[0] for image in pixel_values]
+                    image_features = torch.split(image_features, split_sizes, dim=0)
+
+                    # hd image_features: BS, num_patch, 576, 4096
+                else:
+                    # normal pixel: BS, C=3, H=336, W=336
+                    pixel_values = torch.tensor(
+                        np.array(pixel_values), device=self.vision_tower.device
+                    )
+                    image_features = self.encode_images(pixel_values)
+                    # image_features: BS, 576, 4096
+
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    new_image_features.append(image_feature.flatten(0, 1))
+                image_features = new_image_features
+
+                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                pt = 0
+                for i in range(bs):
+                    if not need_vision[i]:
+                        continue
+
+                    start_idx = extend_start_loc_cpu[i]
+                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                    dim = input_embeds.shape[1]
+                    assert (
+                        pad_dim == dim
+                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
+                    # Fill in the placeholder for the image
+                    try:
+                        input_embeds[
+                            start_idx
+                            + image_offsets[i] : start_idx
+                            + image_offsets[i]
+                            + pad_len
+                        ] = image_features[pt]
+                    except RuntimeError as e:
+                        print(f"RuntimeError in llava image encoding: {e}")
+                        print(input_embeds.shape)
+                        print(start_idx, image_offsets[i])
+                    pt += 1
+
+            return self.language_model(
+                input_ids, positions, input_metadata, input_embeds=input_embeds
+            )
+        elif input_metadata.forward_mode == ForwardMode.DECODE:
+            return self.language_model(input_ids, positions, input_metadata)
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        # load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        vision_path = self.config.mm_vision_tower
+        self.vision_tower = CLIPVisionModel.from_pretrained(
+            vision_path, torch_dtype=torch.float16
+        ).cuda()
+        self.vision_tower.eval()
+
+        self.vision_feature_layer = self.config.mm_vision_select_layer
+        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
+        self.image_size = self.vision_tower.config.image_size
+        self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
+        print(f"target_frames: {self.num_frames}")
+        self.image_feature_len = self.num_frames * int(
+            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
+        )
+        if self.vision_feature_select_strategy == "patch":
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(f"Unexpected select feature: {self.select_feature}")
+
+        # load mm_projector
+        projector_weights = {
+            "model.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+        }
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            # FIXME: why projector weights read two times?
+            if "projector" in name or "vision_tower" in name:
+                for weight_name, param_name in projector_weights.items():
+                    if weight_name in name:
+                        name = name.replace(weight_name, param_name)
+                if name in params_dict:
+                    param = params_dict[name]
+                else:
+                    print(f"Warning: {name} not found in the model")
+                    continue
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        # load language model
+        self.language_model.load_weights(
+            model_name_or_path, cache_dir, load_format, revision
+        )
+
+        monkey_path_clip_vision_embed_forward()
+
+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+
+
+first_call = True
+
+
+def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    batch_size = pixel_values.shape[0]
+
+    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
+    global first_call
+    if first_call:
+        self.patch_embedding.cpu().float()
+        first_call = False
+    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
+    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
+
+    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+    embeddings = embeddings + self.position_embedding(self.position_ids)
+    return embeddings
+
+
+def monkey_path_clip_vision_embed_forward():
+    import transformers
+
+    setattr(
+        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
+        "forward",
+        clip_vision_embed_forward,
+    )
+
+
+EntryClass = LlavaVidForCausalLM
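The `image_feature_len` computed in `load_weights` above determines how many placeholder tokens `pad_input_ids` reserves per video. Below is a quick worked example of that arithmetic; the CLIP ViT-L/14-336 numbers (image_size=336, patch_size=14) are assumptions about a typical `mm_vision_tower`, not values stated in this diff, while num_frames=16 and mm_spatial_pool_stride=2 are the defaults read from the config in `__init__`.

def image_feature_len(num_frames=16, image_size=336, patch_size=14, pool_stride=2):
    # 336 / 14 = 24 patches per side from the CLIP patch embedding;
    # the AvgPool2d resampler with stride 2 halves that to 12 per side,
    # so each frame contributes 12 * 12 = 144 tokens after pooling.
    return num_frames * int((image_size / patch_size / pool_stride) ** 2)


print(image_feature_len())  # 16 frames * 144 tokens/frame = 2304 placeholder tokens

Pooling the spatial grid is what keeps a 16-frame clip at 2304 tokens instead of the 9216 that unpooled 24x24 grids would require.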
sglang/srt/models/mixtral.py CHANGED
@@ -8,34 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class MixtralMLP(nn.Module):
sglang/srt/models/qwen.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class QWenMLP(nn.Module):
@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, quant_config=quant_config)
@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-EntryClass = QWenLMHeadModel
+EntryClass = QWenLMHeadModel
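Both the `EntryClass = QWenLMHeadModel` line here and the new `EntryClass = LlavaVidForCausalLM` in llavavid.py follow the same registration convention: each model file exports an `EntryClass` that the runtime discovers by scanning the models package. The sketch below is only an assumed illustration of that pattern (the helper name `build_model_class_registry` and the exact lookup are not taken from this diff), not sglang's actual loader code.

import importlib
import pkgutil

import sglang.srt.models


def build_model_class_registry():
    # Walk every module under sglang.srt.models and collect its EntryClass,
    # keyed by the class name (which typically matches the HF architecture).
    registry = {}
    for module_info in pkgutil.iter_modules(sglang.srt.models.__path__):
        module = importlib.import_module(f"sglang.srt.models.{module_info.name}")
        entry = getattr(module, "EntryClass", None)
        if entry is not None:
            registry[entry.__name__] = entry
    return registry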
sglang/srt/models/qwen2.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 from torch import nn
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 Qwen2Config = None
 
@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
        )
         if hidden_act != "silu":
             raise ValueError(