sglang 0.2.14.post1__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
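
In the Llama loaders above, the model.vision_tower skip moves from the two load branches to the top of the per-weight handling, so vision-tower tensors (which the Llava wrappers load separately) are dropped before any stacked-parameter name rewriting or params_dict lookup. A minimal standalone sketch of that ordering, with toy names and a toy mapping rather than sglang's real loader:

    from typing import Dict, Iterable, Tuple

    def load_language_model_weights(
        weights: Iterable[Tuple[str, str]],
        params_dict: Dict[str, str],
    ) -> None:
        # Toy stand-in for the (param_name, shard_name, shard_id) mapping.
        stacked_params_mapping = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k")]
        for name, loaded_weight in weights:
            # The hoisted check: drop vision-tower tensors before any name
            # rewriting, so they never reach the params_dict lookups below.
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                fused_name = name.replace(weight_name, param_name)
                print(f"stacked load: {fused_name} <- {name} (shard {shard_id})")
                break
            else:
                print(f"plain load: {name}")

    # The vision-tower tensor is skipped; q_proj folds into qkv_proj; the rest loads as-is.
    load_language_model_weights(
        weights=[
            ("model.vision_tower.patch_embedding.weight", "w0"),
            ("model.layers.0.self_attn.q_proj.weight", "w1"),
            ("model.embed_tokens.weight", "w2"),
        ],
        params_dict={
            "model.layers.0.self_attn.qkv_proj.weight": "p0",
            "model.embed_tokens.weight": "p1",
        },
    )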
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -47,32 +46,19 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
+class LlavaBaseForCausalLM(nn.Module):
+    def pad_input_ids(
         self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +139,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
 
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
+
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset
 
             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
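
The rewritten block above decides per request whether vision inputs are still needed by comparing the request's first extend position with its last image placeholder offset, instead of the old image_feature_len heuristic. A toy numpy sketch with made-up offsets, not tied to sglang's InputMetadata:

    import numpy as np

    # Per-request image placeholder offsets and the first position each request
    # will extend from (all values are made up).
    image_offsets = [[5, 100], [], [3]]
    start_positions = np.array([0, 0, 50])

    # A request still needs vision inputs only if it has images and its first
    # extend position has not yet passed the last image placeholder.
    max_image_offset = np.array(
        [max(offsets) if offsets else -1 for offsets in image_offsets]
    )
    need_vision = start_positions <= max_image_offset
    print(need_vision)  # [ True False False]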
@@ -332,31 +316,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                 new_image_features.append(image_feature)
             image_features = new_image_features
 
+            # Fill in the placeholder for the image
             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
-                pad_dim = image_features[pt].shape[-1]  # 576, 4096
-                dim = input_embeds.shape[1]
-                assert (
-                    pad_dim == dim
-                ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                # Fill in the placeholder for the image
-                try:
-                    for j, image_off in enumerate(image_offsets[i]):
-                        # print("actual image_features length: ", image_features[pt][j].shape[0])
-                        pad_len = image_features[pt][j].shape[0]
-                        input_embeds[
-                            start_idx + image_off : start_idx + image_off + pad_len
-                        ] = image_features[pt][j]
-                except RuntimeError as e:
-                    print(f"RuntimeError in llava image encoding: {e}")
-                    print(image_features[pt].shape)
-                    print(input_embeds.shape)
-                    print(start_idx, image_offsets[i])
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for j, image_offset in enumerate(image_offsets[i]):
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt][j]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
                 pt += 1
 
             return self.language_model(
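
The new placeholder fill accounts for prefix caching: prefix_len tokens of the request are already cached, so each image feature is written at its offset shifted by the prefix length, and images whose placeholder starts inside the cached prefix are skipped. A toy sketch of that index arithmetic with dummy tensors:

    import torch

    start_idx = 10                        # extend_start_loc for this request
    prefix_len = 4                        # tokens already covered by the prefix cache
    image_offset = 7                      # placeholder offset in the full sequence
    tmp_image_feature = torch.ones(3, 8)  # (pad_len, hidden_size) dummy feature
    input_embeds = torch.zeros(32, 8)     # flattened extend-batch embeddings (dummy)

    # Images whose placeholder starts inside the cached prefix are skipped.
    if image_offset >= prefix_len:
        pad_len = tmp_image_feature.shape[0]
        left_idx = start_idx + (image_offset - prefix_len)
        right_idx = left_idx + pad_len
        input_embeds[left_idx:right_idx] = tmp_image_feature

    print(input_embeds[:, 0].nonzero().flatten())  # tensor([13, 14, 15])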
@@ -366,8 +354,9 @@ class LlavaLlamaForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,21 +411,41 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
 
 
-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -462,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )
 
 
-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -495,36 +505,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
             )
 
 
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
             )
 
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        # if self.mm_patch_merge_type.startswith("spatial"):
-        #     height = width = self.num_patches_per_side
-        #     if pt_shape[0] > 1:
-        #         if self.image_aspect_ratio == "anyres":
-        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-        #                 image_size,
-        #                 self.image_grid_pinpoints,
-        #                 self.vision_tower.config.image_size,
-        #             )
-        #         if "unpad" in self.mm_patch_merge_type:
-        #             h = num_patch_height * height
-        #             w = num_patch_width * width
-        #             new_h, new_w = unpad_image_shape(h, w, image_size)
-        #             new_image_feature_len += new_h * (new_w + 1)
 
         pad_ids = pad_value * (
             (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, offset
+        return new_input_ids, [offset]
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
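
pad_input_ids now returns the placeholder offset as a single-element list, matching the multi-image handling downstream. A small standalone sketch of the padding arithmetic shown in this hunk, with made-up values; the real code derives offset from the model's image token id:

    # Stand-in values; the real code uses the config's image token index and the
    # computed image feature length.
    image_token_index = 9
    input_ids = [1, 2, 9, 3, 4]
    pad_value = [101, 102, 103]
    new_image_feature_len = 4

    offset = input_ids.index(image_token_index)
    pad_ids = pad_value * (
        (new_image_feature_len + len(pad_value)) // len(pad_value)
    )
    new_input_ids = (
        input_ids[:offset]
        + pad_ids[:new_image_feature_len]
        + input_ids[offset + 1 :]
    )
    print(new_input_ids, [offset])  # [1, 2, 101, 102, 103, 101, 3, 4] [2]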
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
 
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset
 
             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
 
                 ########## Encode Image ########
 
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
                 new_image_features.append(image_feature.flatten(0, 1))
             image_features = new_image_features
 
+            # Fill in the placeholder for the image
             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                dim = input_embeds.shape[1]
-                assert (
-                    pad_dim == dim
-                ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                # Fill in the placeholder for the image
-                try:
-                    input_embeds[
-                        start_idx
-                        + image_offsets[i] : start_idx
-                        + image_offsets[i]
-                        + pad_len
-                    ] = image_features[pt]
-                except RuntimeError as e:
-                    print(f"RuntimeError in llava image encoding: {e}")
-                    print(input_embeds.shape)
-                    print(start_idx, image_offsets[i])
-                pt += 1
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for image_offset in image_offsets[i]:
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
+                pt += 1
 
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         self.vision_tower = CLIPVisionModel.from_pretrained(
             vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
 
 
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = LlavaVidForCausalLM
@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_weight"
-                    if weight_name in ["gate_proj", "up_proj"]
-                    else "experts.w2_weight"
-                ),
-                f"experts.{expert_id}.{weight_name}.weight",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(self.config.num_experts)
-            for shard_id, weight_name in enumerate(
-                ["gate_proj", "down_proj", "up_proj"]
-            )
-        ]
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
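
The hand-rolled expert mapping is replaced by a call to vLLM's FusedMoE.make_expert_params_mapping with the checkpoint projection names shown above. For reference, the removed comprehension can be run standalone to see the general shape of the (param_name, weight_name, expert_id, shard_id) entries such a mapping contains; num_experts here is a stand-in for self.config.num_experts:

    num_experts = 4  # stand-in for self.config.num_experts

    expert_params_mapping = [
        # (param_name, weight_name, expert_id, shard_id)
        (
            (
                "experts.w13_weight"
                if weight_name in ["gate_proj", "up_proj"]
                else "experts.w2_weight"
            ),
            f"experts.{expert_id}.{weight_name}.weight",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
    ]

    print(expert_params_mapping[:3])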
@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
                    )
sglang/srt/models/yivl.py CHANGED
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.models.llava import (
-    LlavaLlamaForCausalLM,
-    monkey_path_clip_vision_embed_forward,
-)
+from sglang.srt.models.llava import LlavaLlamaForCausalLM
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
-        ).cuda()
+        ).to("cuda")
 
         self.vision_tower.eval()
 
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
 
 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
sglang/srt/server.py CHANGED
@@ -335,12 +335,12 @@ def launch_server(
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
-        start_process = start_controller_process_single
+        start_controller_process = start_controller_process_single
     else:
-        start_process = start_controller_process_multi
+        start_controller_process = start_controller_process_multi
 
     proc_controller = mp.Process(
-        target=start_process,
+        target=start_controller_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()