sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/launch_server_llavavid.py +26 -0
  2. sglang/srt/constrained/fsm_cache.py +11 -2
  3. sglang/srt/constrained/jump_forward.py +1 -0
  4. sglang/srt/hf_transformers_utils.py +0 -149
  5. sglang/srt/layers/activation.py +93 -11
  6. sglang/srt/layers/layernorm.py +47 -4
  7. sglang/srt/layers/logits_processor.py +4 -4
  8. sglang/srt/layers/sampler.py +15 -68
  9. sglang/srt/managers/io_struct.py +5 -4
  10. sglang/srt/managers/schedule_batch.py +20 -25
  11. sglang/srt/managers/tokenizer_manager.py +74 -61
  12. sglang/srt/managers/tp_worker.py +49 -43
  13. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  14. sglang/srt/model_executor/forward_batch_info.py +9 -26
  15. sglang/srt/model_executor/model_runner.py +20 -17
  16. sglang/srt/models/chatglm.py +13 -5
  17. sglang/srt/models/commandr.py +1 -5
  18. sglang/srt/models/dbrx.py +1 -5
  19. sglang/srt/models/deepseek.py +1 -5
  20. sglang/srt/models/deepseek_v2.py +1 -5
  21. sglang/srt/models/gemma.py +3 -7
  22. sglang/srt/models/gemma2.py +2 -56
  23. sglang/srt/models/gpt_bigcode.py +2 -6
  24. sglang/srt/models/grok.py +10 -8
  25. sglang/srt/models/internlm2.py +1 -5
  26. sglang/srt/models/llama2.py +6 -11
  27. sglang/srt/models/llama_classification.py +2 -6
  28. sglang/srt/models/llama_embedding.py +3 -4
  29. sglang/srt/models/llava.py +69 -91
  30. sglang/srt/models/llavavid.py +40 -86
  31. sglang/srt/models/minicpm.py +1 -5
  32. sglang/srt/models/mixtral.py +1 -5
  33. sglang/srt/models/mixtral_quant.py +1 -5
  34. sglang/srt/models/qwen.py +2 -5
  35. sglang/srt/models/qwen2.py +5 -10
  36. sglang/srt/models/qwen2_moe.py +21 -24
  37. sglang/srt/models/stablelm.py +1 -5
  38. sglang/srt/models/yivl.py +2 -7
  39. sglang/srt/openai_api/adapter.py +85 -4
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +1 -74
  42. sglang/srt/sampling/sampling_params.py +4 -0
  43. sglang/srt/server.py +11 -4
  44. sglang/srt/utils.py +18 -33
  45. sglang/test/runners.py +2 -2
  46. sglang/test/test_layernorm.py +53 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
  49. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
  50. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
  51. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llavavid.py CHANGED
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.mm_utils import (
- get_anyres_image_grid_shape,
- unpad_image,
- unpad_image_shape,
- )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.models.llama2 import LlamaForCausalLM

@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
  torch.empty(config.text_config.hidden_size, dtype=torch.float16)
  )

- def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+ def pad_input_ids(
+ self,
+ input_ids: List[int],
+ pad_value: List[int],
+ pixel_values: List,
+ image_sizes: List[List[int]],
+ ):
  new_image_feature_len = self.image_feature_len
- # now only support spatial_unpad + anyres
- # if self.mm_patch_merge_type.startswith("spatial"):
- # height = width = self.num_patches_per_side
- # if pt_shape[0] > 1:
- # if self.image_aspect_ratio == "anyres":
- # num_patch_width, num_patch_height = get_anyres_image_grid_shape(
- # image_size,
- # self.image_grid_pinpoints,
- # self.vision_tower.config.image_size,
- # )
- # if "unpad" in self.mm_patch_merge_type:
- # h = num_patch_height * height
- # w = num_patch_width * width
- # new_h, new_w = unpad_image_shape(h, w, image_size)
- # new_image_feature_len += new_h * (new_w + 1)

  pad_ids = pad_value * (
  (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
  + pad_ids[:new_image_feature_len]
  + input_ids[offset + 1 :]
  )
- return new_input_ids, offset
+ return new_input_ids, [offset]

  def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
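The pad_input_ids rewrite above is the core of the multi-image handling in this release: image placeholders are expanded to the image feature length and the method now returns a list of offsets instead of a single scalar. A minimal sketch of that contract, using a hypothetical helper and a made-up image token id (neither is sglang API):

```python
from typing import List, Tuple

IMAGE_TOKEN_ID = -200  # illustrative placeholder id, not sglang's constant

def pad_prompt_for_images(
    input_ids: List[int],
    pad_value: List[int],
    image_feature_len: int,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mirroring the new pad_input_ids contract:
    expand each image token into image_feature_len pad ids and record
    the offset where each image's features will later be written."""
    new_input_ids: List[int] = []
    offsets: List[int] = []
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID:
            offsets.append(len(new_input_ids))
            reps = (image_feature_len + len(pad_value)) // len(pad_value)
            new_input_ids.extend((pad_value * reps)[:image_feature_len])
        else:
            new_input_ids.append(tok)
    return new_input_ids, offsets  # offsets is a list, as in the new code

# Example: one image placeholder in a short prompt.
ids, offs = pad_prompt_for_images([1, 2, IMAGE_TOKEN_ID, 3], pad_value=[9], image_feature_len=4)
assert ids == [1, 2, 9, 9, 9, 9, 3] and offs == [2]
```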
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
  if input_metadata.forward_mode == ForwardMode.EXTEND:
  bs = input_metadata.batch_size

- # Embed text input
+ # Embed text inputs
  input_embeds = self.language_model.model.embed_tokens(input_ids)

- # Embed vision input
- need_vision = (
- (positions[input_metadata.extend_start_loc] < self.image_feature_len)
- .cpu()
- .numpy()
+ # Whether the requests need vision inputs
+ max_image_offset = np.array(
+ [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
  )
- # FIXME: We need to substract the length of the system prompt
- has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
- need_vision = need_vision & has_pixel
+ start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+ need_vision = start_positions <= max_image_offset

  if need_vision.any():
  pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
- image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

  ########## Encode Image ########
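The need_vision test above no longer compares positions against a fixed image_feature_len; a request is sent through the vision tower only if its extend range still reaches one of its image placeholders. A toy numpy illustration of the new check (values are made up):

```python
import numpy as np

# Per-request image placeholder offsets and the first position each request
# will extend from (positions[extend_start_loc] in the real code).
image_offsets = [[5], [], [3, 700]]
start_positions = np.array([0, 12, 650])

max_image_offset = np.array(
    [max(offs) if offs else -1 for offs in image_offsets]
)
need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False  True]
```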
 
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
  new_image_features.append(image_feature.flatten(0, 1))
  image_features = new_image_features

+ # Fill in the placeholder for the image
  extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+ prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
  pt = 0
  for i in range(bs):
  if not need_vision[i]:
  continue

  start_idx = extend_start_loc_cpu[i]
- pad_len, pad_dim = image_features[pt].shape # 576, 4096
- dim = input_embeds.shape[1]
- assert (
- pad_dim == dim
- ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
- # Fill in the placeholder for the image
- try:
- input_embeds[
- start_idx
- + image_offsets[i] : start_idx
- + image_offsets[i]
- + pad_len
- ] = image_features[pt]
- except RuntimeError as e:
- print(f"RuntimeError in llava image encoding: {e}")
- print(input_embeds.shape)
- print(start_idx, image_offsets[i])
- pt += 1
+ prefix_len = prefix_lens_cpu[i]
+
+ # Multiple images
+ for image_offset in image_offsets[i]:
+ if image_offset < prefix_len:
+ continue
+
+ tmp_image_feature = image_features[pt]
+ pad_len = tmp_image_feature.shape[0]
+
+ left_idx = start_idx + (image_offset - prefix_len)
+ right_idx = start_idx + (image_offset - prefix_len) + pad_len
+ try:
+ input_embeds[left_idx:right_idx] = tmp_image_feature
+ except RuntimeError as e:
+ print(f"RuntimeError in image encoding: {e}")
+ print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+ print(
+ f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+ )
+ pt += 1

  return self.language_model(
  input_ids, positions, input_metadata, input_embeds=input_embeds
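The new placeholder-filling loop indexes the embedding buffer relative to each request's extend start and skips images whose placeholders are already covered by the cached prefix. A toy sketch of that indexing arithmetic (sizes are illustrative, not sglang's batched layout):

```python
import torch

hidden = 8
input_embeds = torch.zeros(20, hidden)   # embeddings of this request's extend tokens
start_idx, prefix_len = 0, 4             # 4 prompt tokens already live in the prefix cache
image_offsets = [6, 14]                  # placeholder offsets within the full prompt
image_features = [torch.randn(3, hidden), torch.randn(3, hidden)]

for feat, image_offset in zip(image_features, image_offsets):
    if image_offset < prefix_len:
        continue  # placeholder already covered by the cached prefix
    left = start_idx + (image_offset - prefix_len)
    right = left + feat.shape[0]
    input_embeds[left:right] = feat
```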
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
  return self.language_model(input_ids, positions, input_metadata)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- # load clip vision model by cfg['mm_vision_tower']:
- # huggingface_name or path_of_clip_relative_to_llava_model_dir
+ # Load clip vision model by cfg['mm_vision_tower']:
+ # huggingface_name or path_of_clip_relative_to_llava_model_dir
+ # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
  vision_path = self.config.mm_vision_tower
  self.vision_tower = CLIPVisionModel.from_pretrained(
  vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
  # load language model
  self.language_model.load_weights(weights)

- monkey_path_clip_vision_embed_forward()
-
  @property
  def num_patches_per_side(self):
  return self.image_size // self.patch_size


- first_call = True
-
-
- def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
- batch_size = pixel_values.shape[0]
-
- # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
- global first_call
- if first_call:
- self.patch_embedding.cpu().float()
- first_call = False
- pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
- patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
- embeddings = embeddings + self.position_embedding(self.position_ids)
- return embeddings
-
-
- def monkey_path_clip_vision_embed_forward():
- import transformers
-
- setattr(
- transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
- "forward",
- clip_vision_embed_forward,
- )
-
-
  EntryClass = LlavaVidForCausalLM
sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
  self.scale_width = self.config.hidden_size / self.config.dim_model_base

  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -316,11 +314,9 @@
  lm_head_weight = self.model.embed_tokens.weight
  else:
  lm_head_weight = self.lm_head.weight
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, lm_head_weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
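The MiniCPM change above is the first instance of a pattern repeated in the Mixtral, Mixtral-quant, Qwen, Qwen2, Qwen2-MoE, and StableLM diffs below: the per-model Sampler is deleted and forward() returns the LogitsProcessor output directly, so sampling moves out of the model's forward path (the sampler.py and tp_worker.py entries in the file list change accordingly). A self-contained toy model showing the new shape of the contract; the class here is a stand-in, not sglang code:

```python
import torch

class TinyLM(torch.nn.Module):
    """Stand-in model following the 0.2.14.post2 convention: forward()
    returns logits only and never samples tokens itself."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed(input_ids)
        return self.lm_head(hidden_states)  # no per-model Sampler anymore

# The caller (the runner/worker in sglang) now owns sampling:
model = TinyLM()
logits = model(torch.tensor([[1, 2, 3]]))
next_token = torch.argmax(logits[:, -1, :], dim=-1)  # greedy pick, for illustration
```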
sglang/srt/models/mixtral.py CHANGED
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
  self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  def forward(
  self,
@@ -310,11 +308,9 @@
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
  self.model = MixtralModel(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -345,11 +343,9 @@
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
  vocab_size = ((config.vocab_size + 63) // 64) * 64
  self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -262,11 +260,10 @@
  input_metadata: InputMetadata,
  ):
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- logits_output = self.logits_processor(
+ next_tokens = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output
+ return next_tokens

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -38,9 +38,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
  self.model = Qwen2Model(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

  @torch.no_grad()
@@ -291,11 +289,9 @@
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
  if not get_embedding:
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output
  else:
  return self.pooler(hidden_states, input_metadata)
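Qwen2 keeps a second branch that returns pooled embeddings instead of logits when get_embedding is set, and this release imports EmbeddingPoolerOutput alongside the Pooler. As a rough illustration of what a last-token pooler with normalize=True computes (this is not the sglang Pooler implementation):

```python
import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 5, 8)   # [batch, seq, hidden], toy values
seq_lens = torch.tensor([5, 3])        # true lengths of the two sequences

last_idx = seq_lens - 1
pooled = hidden_states[torch.arange(hidden_states.size(0)), last_idx]
embeddings = F.normalize(pooled, p=2, dim=-1)  # one unit-norm vector per request
```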
 
@@ -316,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
  continue
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ continue
+
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
@@ -323,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(param, loaded_weight, shard_id)
@@ -333,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
sglang/srt/models/qwen2_moe.py CHANGED
@@ -35,8 +35,10 @@ from vllm.model_executor.layers.linear import (
  ReplicatedLinear,
  RowParallelLinear,
  )
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -47,7 +49,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -365,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -376,11 +376,20 @@ class Qwen2MoeForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output
+
+ def compute_logits(
+ self,
+ input_ids: torch.Tensor,
+ hidden_states: torch.Tensor,
+ input_metadata: InputMetadata,
+ ) -> torch.Tensor:
+ logits = self.logits_processor(
+ input_ids, hidden_states, self.lm_head.weight, input_metadata
+ )
+ return logits

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
@@ -392,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
  ("gate_up_proj", "up_proj", 1),
  ]

- expert_params_mapping = [
- # These are the weights for the experts
- # (param_name, weight_name, expert_id, shard_id)
- (
- (
- "experts.w13_weight"
- if weight_name in ["gate_proj", "up_proj"]
- else "experts.w2_weight"
- ),
- f"experts.{expert_id}.{weight_name}.weight",
- expert_id,
- shard_id,
- )
- for expert_id in range(self.config.num_experts)
- for shard_id, weight_name in enumerate(
- ["gate_proj", "down_proj", "up_proj"]
- )
- ]
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts,
+ )

  params_dict = dict(self.named_parameters())
  for name, loaded_weight in weights:
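The hand-written expert mapping above is replaced by vLLM's FusedMoE.make_expert_params_mapping helper. For reference, the removed list comprehension built roughly the following (param_name, weight_name, expert_id, shard_id) tuples; the helper is expected to yield an equivalent mapping, though its exact output naming is defined by vLLM:

```python
# Readable reconstruction of the mapping the removed code produced.
num_experts = 4  # illustrative; the real value is self.config.num_experts

expert_params_mapping = [
    (
        "experts.w13_weight"
        if weight_name in ("gate_proj", "up_proj")
        else "experts.w2_weight",
        f"experts.{expert_id}.{weight_name}.weight",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]
```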
@@ -449,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
  weight_loader(
  param,
  loaded_weight,
- weight_name,
+ name,
  shard_id=shard_id,
  expert_id=expert_id,
  )
sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
  self.model = StableLMEpochModel(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -261,11 +259,9 @@
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/yivl.py CHANGED
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.models.llava import (
- LlavaLlamaForCausalLM,
- monkey_path_clip_vision_embed_forward,
- )
+ from sglang.srt.models.llava import LlavaLlamaForCausalLM


  class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
  self.config._name_or_path,
  torch_dtype=torch.float16,
  subfolder=self.vision_tower_subfolder,
- ).cuda()
+ ).to("cuda")

  self.vision_tower.eval()

@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
  # load language model
  self.language_model.load_weights(weights)

- monkey_path_clip_vision_embed_forward()
-

  class YiVLMultiModalProjector(nn.Module):
  def __init__(self, config: LlavaConfig):