sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +26 -0
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -160
  11. sglang/srt/layers/activation.py +10 -4
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/layernorm.py +47 -1
  14. sglang/srt/layers/logits_processor.py +4 -4
  15. sglang/srt/layers/sampler.py +69 -16
  16. sglang/srt/managers/controller_multi.py +5 -5
  17. sglang/srt/managers/controller_single.py +5 -5
  18. sglang/srt/managers/io_struct.py +11 -5
  19. sglang/srt/managers/schedule_batch.py +25 -13
  20. sglang/srt/managers/tokenizer_manager.py +76 -63
  21. sglang/srt/managers/tp_worker.py +47 -36
  22. sglang/srt/model_config.py +3 -3
  23. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  24. sglang/srt/model_executor/forward_batch_info.py +78 -43
  25. sglang/srt/model_executor/model_runner.py +29 -18
  26. sglang/srt/models/chatglm.py +5 -13
  27. sglang/srt/models/commandr.py +5 -1
  28. sglang/srt/models/dbrx.py +5 -1
  29. sglang/srt/models/deepseek.py +5 -1
  30. sglang/srt/models/deepseek_v2.py +57 -25
  31. sglang/srt/models/exaone.py +399 -0
  32. sglang/srt/models/gemma.py +7 -3
  33. sglang/srt/models/gemma2.py +6 -52
  34. sglang/srt/models/gpt_bigcode.py +5 -1
  35. sglang/srt/models/grok.py +14 -4
  36. sglang/srt/models/internlm2.py +5 -1
  37. sglang/srt/models/llama2.py +10 -7
  38. sglang/srt/models/llama_classification.py +2 -6
  39. sglang/srt/models/llama_embedding.py +3 -4
  40. sglang/srt/models/llava.py +69 -91
  41. sglang/srt/models/llavavid.py +40 -86
  42. sglang/srt/models/minicpm.py +5 -1
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +9 -6
  47. sglang/srt/models/qwen2_moe.py +12 -33
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/models/yivl.py +2 -7
  50. sglang/srt/openai_api/adapter.py +16 -1
  51. sglang/srt/openai_api/protocol.py +5 -5
  52. sglang/srt/sampling/sampling_batch_info.py +79 -6
  53. sglang/srt/server.py +9 -9
  54. sglang/srt/utils.py +18 -36
  55. sglang/test/runners.py +2 -2
  56. sglang/test/test_layernorm.py +53 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
  59. sglang-0.2.15.dist-info/RECORD +118 -0
  60. sglang-0.2.14.post1.dist-info/RECORD +0 -114
  61. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/llavavid.py CHANGED
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.mm_utils import (
-     get_anyres_image_grid_shape,
-     unpad_image,
-     unpad_image_shape,
- )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.models.llama2 import LlamaForCausalLM

@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
              torch.empty(config.text_config.hidden_size, dtype=torch.float16)
          )

-     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+     def pad_input_ids(
+         self,
+         input_ids: List[int],
+         pad_value: List[int],
+         pixel_values: List,
+         image_sizes: List[List[int]],
+     ):
          new_image_feature_len = self.image_feature_len
-         # now only support spatial_unpad + anyres
-         # if self.mm_patch_merge_type.startswith("spatial"):
-         #     height = width = self.num_patches_per_side
-         #     if pt_shape[0] > 1:
-         #         if self.image_aspect_ratio == "anyres":
-         #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-         #                 image_size,
-         #                 self.image_grid_pinpoints,
-         #                 self.vision_tower.config.image_size,
-         #             )
-         #         if "unpad" in self.mm_patch_merge_type:
-         #             h = num_patch_height * height
-         #             w = num_patch_width * width
-         #             new_h, new_w = unpad_image_shape(h, w, image_size)
-         #             new_image_feature_len += new_h * (new_w + 1)

          pad_ids = pad_value * (
              (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
              + pad_ids[:new_image_feature_len]
              + input_ids[offset + 1 :]
          )
-         return new_input_ids, offset
+         return new_input_ids, [offset]

      def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
          image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
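For orientation, a minimal standalone sketch of the behavior implied by the new signature: the single image placeholder token is replaced by image_feature_len pad ids, and the offset is now returned as a list (one entry per image). The function name and the placeholder id 32000 are illustrative, not sglang's API.

from typing import List, Tuple

def demo_pad_input_ids(
    input_ids: List[int],
    pad_value: List[int],
    image_feature_len: int,
    image_token_id: int,
) -> Tuple[List[int], List[int]]:
    # Repeat pad_value until it covers image_feature_len positions.
    pad_ids = pad_value * ((image_feature_len + len(pad_value)) // len(pad_value))
    offset = input_ids.index(image_token_id)
    new_input_ids = input_ids[:offset] + pad_ids[:image_feature_len] + input_ids[offset + 1 :]
    # Offsets are returned as a list, matching the new `[offset]` return value above.
    return new_input_ids, [offset]

print(demo_pad_input_ids([1, 2, 32000, 3], pad_value=[9], image_feature_len=4, image_token_id=32000))
# ([1, 2, 9, 9, 9, 9, 3], [2])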
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
          if input_metadata.forward_mode == ForwardMode.EXTEND:
              bs = input_metadata.batch_size

-             # Embed text input
+             # Embed text inputs
              input_embeds = self.language_model.model.embed_tokens(input_ids)

-             # Embed vision input
-             need_vision = (
-                 (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                 .cpu()
-                 .numpy()
+             # Whether the requests need vision inputs
+             max_image_offset = np.array(
+                 [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
              )
-             # FIXME: We need to substract the length of the system prompt
-             has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-             need_vision = need_vision & has_pixel
+             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+             need_vision = start_positions <= max_image_offset

              if need_vision.any():
                  pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                 image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                  ########## Encode Image ########

@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
                      new_image_features.append(image_feature.flatten(0, 1))
                  image_features = new_image_features

+                 # Fill in the placeholder for the image
                  extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                 prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                  pt = 0
                  for i in range(bs):
                      if not need_vision[i]:
                          continue

                      start_idx = extend_start_loc_cpu[i]
-                     pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                     dim = input_embeds.shape[1]
-                     assert (
-                         pad_dim == dim
-                     ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                     # Fill in the placeholder for the image
-                     try:
-                         input_embeds[
-                             start_idx
-                             + image_offsets[i] : start_idx
-                             + image_offsets[i]
-                             + pad_len
-                         ] = image_features[pt]
-                     except RuntimeError as e:
-                         print(f"RuntimeError in llava image encoding: {e}")
-                         print(input_embeds.shape)
-                         print(start_idx, image_offsets[i])
-                     pt += 1
+                     prefix_len = prefix_lens_cpu[i]
+
+                     # Multiple images
+                     for image_offset in image_offsets[i]:
+                         if image_offset < prefix_len:
+                             continue
+
+                         tmp_image_feature = image_features[pt]
+                         pad_len = tmp_image_feature.shape[0]
+
+                         left_idx = start_idx + (image_offset - prefix_len)
+                         right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                         try:
+                             input_embeds[left_idx:right_idx] = tmp_image_feature
+                         except RuntimeError as e:
+                             print(f"RuntimeError in image encoding: {e}")
+                             print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                             print(
+                                 f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                             )
+                         pt += 1

          return self.language_model(
              input_ids, positions, input_metadata, input_embeds=input_embeds
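The index arithmetic in the new fill-in loop shifts each image offset by the cached prefix length before writing the image features into input_embeds; a toy illustration with made-up numbers:

start_idx = 0      # where this request's extended tokens begin in input_embeds
prefix_len = 8     # prompt tokens already covered by the prefix cache
image_offset = 12  # placeholder position in the full prompt
pad_len = 576      # number of image feature vectors to write

left_idx = start_idx + (image_offset - prefix_len)
right_idx = left_idx + pad_len
print(left_idx, right_idx)  # 4 580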
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
              return self.language_model(input_ids, positions, input_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-         # load clip vision model by cfg['mm_vision_tower']:
-         # huggingface_name or path_of_clip_relative_to_llava_model_dir
+         # Load clip vision model by cfg['mm_vision_tower']:
+         # huggingface_name or path_of_clip_relative_to_llava_model_dir
+         # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
          vision_path = self.config.mm_vision_tower
          self.vision_tower = CLIPVisionModel.from_pretrained(
              vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
          # load language model
          self.language_model.load_weights(weights)

-         monkey_path_clip_vision_embed_forward()
-
      @property
      def num_patches_per_side(self):
          return self.image_size // self.patch_size


- first_call = True
-
-
- def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-     batch_size = pixel_values.shape[0]
-
-     # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-     global first_call
-     if first_call:
-         self.patch_embedding.cpu().float()
-         first_call = False
-     pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-     patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-     patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-     class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-     embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-     embeddings = embeddings + self.position_embedding(self.position_ids)
-     return embeddings
-
-
- def monkey_path_clip_vision_embed_forward():
-     import transformers
-
-     setattr(
-         transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-         "forward",
-         clip_vision_embed_forward,
-     )
-
-
  EntryClass = LlavaVidForCausalLM
sglang/srt/models/minicpm.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
          self.scale_width = self.config.hidden_size / self.config.dim_model_base

          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
              lm_head_weight = self.model.embed_tokens.weight
          else:
              lm_head_weight = self.lm_head.weight
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, lm_head_weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
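This is the first of several models in this release whose forward() now runs the sampler itself and returns a (sample_output, logits_output) pair instead of bare logits. A toy sketch of that contract; the classes below are stand-ins, not sglang's Sampler or LogitsProcessor:

import torch

class ToySampler:
    def __call__(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Sample one token id per row from the softmax distribution.
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)

class ToyModel(torch.nn.Module):
    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
        self.sampler = ToySampler()

    def forward(self, hidden_states: torch.Tensor):
        logits_output = self.lm_head(hidden_states)
        sample_output = self.sampler(logits_output)
        return sample_output, logits_output  # the new two-value contract

tokens, logits = ToyModel()(torch.randn(2, 16))
print(tokens.shape, logits.shape)  # torch.Size([2, 1]) torch.Size([2, 32])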
sglang/srt/models/mixtral.py CHANGED
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
          self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      def forward(
          self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -358,7 +362,7 @@ class MixtralForCausalLM(nn.Module):
                      weight_loader(
                          param,
                          loaded_weight,
-                         weight_name,
+                         name,
                          shard_id=shard_id,
                          expert_id=expert_id,
                      )
45
45
  from sglang.srt.layers.layernorm import RMSNorm
46
46
  from sglang.srt.layers.logits_processor import LogitsProcessor
47
47
  from sglang.srt.layers.radix_attention import RadixAttention
48
+ from sglang.srt.layers.sampler import Sampler
48
49
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
49
50
 
50
51
 
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
333
334
  self.model = MixtralModel(config, quant_config=quant_config)
334
335
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
335
336
  self.logits_processor = LogitsProcessor(config)
337
+ self.sampler = Sampler()
336
338
 
337
339
  @torch.no_grad()
338
340
  def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
343
345
  input_embeds: torch.Tensor = None,
344
346
  ) -> torch.Tensor:
345
347
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
346
- return self.logits_processor(
348
+ logits_output = self.logits_processor(
347
349
  input_ids, hidden_states, self.lm_head.weight, input_metadata
348
350
  )
351
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
352
+ return sample_output, logits_output
349
353
 
350
354
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
351
355
  stacked_params_mapping = [
sglang/srt/models/qwen.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
          vocab_size = ((config.vocab_size + 63) // 64) * 64
          self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
          input_metadata: InputMetadata,
      ):
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         next_tokens = self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         return next_tokens
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  Qwen2Config = None
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
          self.model = Qwen2Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

      @torch.no_grad()
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
          if not get_embedding:
-             return self.logits_processor(
+             logits_output = self.logits_processor(
                  input_ids, hidden_states, self.lm_head.weight, input_metadata
              )
+             sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+             return sample_output, logits_output
          else:
              return self.pooler(hidden_states, input_metadata)

@@ -312,6 +316,9 @@ class Qwen2ForCausalLM(nn.Module):
                  # Models trained using ColossalAI may include these tensors in
                  # the checkpoint. Skip them.
                  continue
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 continue
+
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
@@ -319,8 +326,6 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
-                 if name.startswith("model.vision_tower") and name not in params_dict:
-                     continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
@@ -329,8 +334,6 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
-                 if name.startswith("model.vision_tower") and name not in params_dict:
-                     continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
sglang/srt/models/qwen2_moe.py CHANGED
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
      ReplicatedLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.logits_processor import LogitsProcessor
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-
-     def compute_logits(
-         self,
-         input_ids: torch.Tensor,
-         hidden_states: torch.Tensor,
-         input_metadata: InputMetadata,
-     ) -> torch.Tensor:
-         logits = self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, input_metadata
-         )
-         return logits
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -401,24 +392,12 @@ class Qwen2MoeForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]

-         expert_params_mapping = [
-             # These are the weights for the experts
-             # (param_name, weight_name, expert_id, shard_id)
-             (
-                 (
-                     "experts.w13_weight"
-                     if weight_name in ["gate_proj", "up_proj"]
-                     else "experts.w2_weight"
-                 ),
-                 f"experts.{expert_id}.{weight_name}.weight",
-                 expert_id,
-                 shard_id,
-             )
-             for expert_id in range(self.config.num_experts)
-             for shard_id, weight_name in enumerate(
-                 ["gate_proj", "down_proj", "up_proj"]
-             )
-         ]
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )

          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
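The hand-rolled comprehension is replaced by vLLM's FusedMoE.make_expert_params_mapping helper. As a rough reconstruction of the kind of entries such a mapping contains, based only on the deleted comprehension above (num_experts=2 for illustration; the helper's exact tuple layout is defined by vLLM):

num_experts = 2  # illustrative; the model uses self.config.num_experts
expert_params_mapping = [
    (
        # (param_name, weight_name, expert_id, shard_id)
        "experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] else "experts.w2_weight",
        f"experts.{expert_id}.{weight_name}.weight",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]
for entry in expert_params_mapping:
    print(entry)
# ('experts.w13_weight', 'experts.0.gate_proj.weight', 0, 0)
# ('experts.w2_weight', 'experts.0.down_proj.weight', 0, 1)
# ...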
@@ -458,7 +437,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                      weight_loader(
                          param,
                          loaded_weight,
-                         weight_name,
+                         name,
                          shard_id=shard_id,
                          expert_id=expert_id,
                      )
sglang/srt/models/stablelm.py CHANGED
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
          self.model = StableLMEpochModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/yivl.py CHANGED
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.models.llava import (
-     LlavaLlamaForCausalLM,
-     monkey_path_clip_vision_embed_forward,
- )
+ from sglang.srt.models.llava import LlavaLlamaForCausalLM


  class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
              self.config._name_or_path,
              torch_dtype=torch.float16,
              subfolder=self.vision_tower_subfolder,
-         ).cuda()
+         ).to("cuda")

          self.vision_tower.eval()

@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
          # load language model
          self.language_model.load_weights(weights)

-         monkey_path_clip_vision_embed_forward()
-

  class YiVLMultiModalProjector(nn.Module):
      def __init__(self, config: LlavaConfig):
sglang/srt/openai_api/adapter.py CHANGED
@@ -844,8 +844,23 @@ def v1_chat_generate_request(
          if not isinstance(request.messages, str):
              # Apply chat template and its stop strings.
              if chat_template_name is None:
+                 openai_compatible_messages = []
+                 for message in request.messages:
+                     if isinstance(message.content, str):
+                         openai_compatible_messages.append(
+                             {"role": message.role, "content": message.content}
+                         )
+                     else:
+                         content_list = message.dict()["content"]
+                         for content in content_list:
+                             if content["type"] == "text":
+                                 openai_compatible_messages.append(
+                                     {"role": message.role, "content": content["text"]}
+                                 )
                  prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
-                     request.messages, tokenize=True, add_generation_prompt=True
+                     openai_compatible_messages,
+                     tokenize=True,
+                     add_generation_prompt=True,
                  )
                  stop = request.stop
                  image_data = None
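The new block flattens multi-part message content down to plain text before the tokenizer's chat template is applied. An illustrative run, with plain dicts standing in for the pydantic message objects:

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]},
]

openai_compatible_messages = []
for message in messages:
    if isinstance(message["content"], str):
        openai_compatible_messages.append({"role": message["role"], "content": message["content"]})
    else:
        # Keep only the text parts; image parts are handled separately as image_data.
        for content in message["content"]:
            if content["type"] == "text":
                openai_compatible_messages.append({"role": message["role"], "content": content["text"]})

print(openai_compatible_messages)
# [{'role': 'system', 'content': 'You are helpful.'},
#  {'role': 'user', 'content': 'Describe this image.'}]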
sglang/srt/openai_api/protocol.py CHANGED
@@ -200,11 +200,6 @@ class CompletionStreamResponse(BaseModel):
      usage: Optional[UsageInfo] = None


- class ChatCompletionMessageGenericParam(BaseModel):
-     role: Literal["system", "assistant"]
-     content: str
-
-
  class ChatCompletionMessageContentTextPart(BaseModel):
      type: Literal["text"]
      text: str
@@ -225,6 +220,11 @@ ChatCompletionMessageContentPart = Union[
  ]


+ class ChatCompletionMessageGenericParam(BaseModel):
+     role: Literal["system", "assistant"]
+     content: Union[str, List[ChatCompletionMessageContentTextPart]]
+
+
  class ChatCompletionMessageUserParam(BaseModel):
      role: Literal["user"]
      content: Union[str, List[ChatCompletionMessageContentPart]]
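With ChatCompletionMessageGenericParam moved below the text-part model and its content field widened, system and assistant messages may now carry either a plain string or a list of text parts. A standalone sketch mirroring the models above (for illustration only):

from typing import List, Literal, Union
from pydantic import BaseModel

class ChatCompletionMessageContentTextPart(BaseModel):
    type: Literal["text"]
    text: str

class ChatCompletionMessageGenericParam(BaseModel):
    role: Literal["system", "assistant"]
    content: Union[str, List[ChatCompletionMessageContentTextPart]]

# Both forms validate.
print(ChatCompletionMessageGenericParam(role="system", content="You are helpful."))
print(ChatCompletionMessageGenericParam(
    role="system", content=[{"type": "text", "text": "You are helpful."}]
))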