sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (64)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/backend/runtime_endpoint.py +8 -4
  4. sglang/lang/interpreter.py +3 -0
  5. sglang/lang/ir.py +5 -0
  6. sglang/launch_server_llavavid.py +12 -12
  7. sglang/srt/configs/__init__.py +5 -0
  8. sglang/srt/configs/exaone.py +195 -0
  9. sglang/srt/constrained/fsm_cache.py +1 -1
  10. sglang/srt/conversation.py +24 -2
  11. sglang/srt/hf_transformers_utils.py +12 -12
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/sampler.py +94 -17
  15. sglang/srt/managers/controller_multi.py +5 -5
  16. sglang/srt/managers/controller_single.py +5 -5
  17. sglang/srt/managers/io_struct.py +6 -1
  18. sglang/srt/managers/schedule_batch.py +26 -11
  19. sglang/srt/managers/tokenizer_manager.py +9 -9
  20. sglang/srt/managers/tp_worker.py +38 -26
  21. sglang/srt/model_config.py +3 -3
  22. sglang/srt/model_executor/cuda_graph_runner.py +26 -9
  23. sglang/srt/model_executor/forward_batch_info.py +68 -23
  24. sglang/srt/model_executor/model_runner.py +15 -22
  25. sglang/srt/models/chatglm.py +9 -15
  26. sglang/srt/models/commandr.py +5 -1
  27. sglang/srt/models/dbrx.py +5 -1
  28. sglang/srt/models/deepseek.py +5 -1
  29. sglang/srt/models/deepseek_v2.py +57 -25
  30. sglang/srt/models/exaone.py +368 -0
  31. sglang/srt/models/gemma.py +5 -1
  32. sglang/srt/models/gemma2.py +5 -1
  33. sglang/srt/models/gpt_bigcode.py +5 -1
  34. sglang/srt/models/grok.py +5 -1
  35. sglang/srt/models/internlm2.py +5 -1
  36. sglang/srt/models/{llama2.py → llama.py} +25 -45
  37. sglang/srt/models/llama_classification.py +34 -41
  38. sglang/srt/models/llama_embedding.py +7 -6
  39. sglang/srt/models/llava.py +8 -11
  40. sglang/srt/models/llavavid.py +5 -6
  41. sglang/srt/models/minicpm.py +5 -1
  42. sglang/srt/models/mistral.py +2 -3
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +6 -2
  47. sglang/srt/models/qwen2_moe.py +5 -14
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/openai_api/adapter.py +16 -1
  50. sglang/srt/openai_api/protocol.py +5 -5
  51. sglang/srt/sampling/sampling_batch_info.py +75 -6
  52. sglang/srt/server.py +6 -6
  53. sglang/srt/utils.py +0 -3
  54. sglang/test/runners.py +1 -1
  55. sglang/test/test_programs.py +68 -0
  56. sglang/test/test_utils.py +4 -0
  57. sglang/utils.py +39 -0
  58. sglang/version.py +1 -1
  59. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
  60. sglang-0.3.0.dist-info/RECORD +118 -0
  61. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
  62. sglang-0.2.14.post2.dist-info/RECORD +0 -115
  63. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
  64. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_classification.py CHANGED
@@ -16,17 +16,16 @@ limitations under the License.
 from typing import Iterable, Optional, Tuple
 
 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
-from sglang.srt.models.llama2 import LlamaModel
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
 class LlamaForClassification(nn.Module):
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
 
         self.classification_head = nn.Linear(
-            config.hidden_size, config.classification_out_size
+            config.hidden_size, config.classification_out_size, bias=False
         )
         self.eos_token_id = config.eos_token_id
 
+        self.param_dict = dict(self.named_parameters())
+
     @torch.no_grad()
     def forward(
         self,
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
                 (input_metadata.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)
 
-        return LogitProcessorOutput(
+        logits_output = LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )
 
+        # A dummy to make this work
+        sample_output = SampleOutput(
+            success=torch.full(
+                size=(scores.shape[0],),
+                fill_value=True,
+                dtype=torch.bool,
+            ),
+            probs=torch.full(
+                size=(scores.shape[0], 1),
+                fill_value=1.0,
+                dtype=torch.float16,
+            ),
+            batch_next_token_ids=torch.full(
+                size=(scores.shape[0],),
+                fill_value=0,
+                dtype=torch.long,
+            ),
+        )
+        return sample_output, logits_output
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
-            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if "lm_head" in name:
-                continue
+        params_dict = self.param_dict
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
 
 
 EntryClass = LlamaForClassification
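
The rewritten load_weights drops the duplicated stacked-parameter bookkeeping: the classification head loads its own tensor, lm_head is skipped, and every other tensor is forwarded one at a time to the shared Llama loader. A minimal sketch of that delegation pattern (class and tensor names here are illustrative, not the sglang API):

    from typing import Iterable, Tuple

    import torch
    from torch import nn


    class BaseLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Linear(8, 8, bias=False)

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            params = dict(self.named_parameters())
            for name, tensor in weights:
                params[name].data.copy_(tensor)


    class Classifier(BaseLM):
        def __init__(self):
            super().__init__()
            self.classification_head = nn.Linear(8, 2, bias=False)

        def load_weights(self, weights):
            params = dict(self.named_parameters())
            for name, tensor in weights:
                if "classification_head" in name:
                    params[name].data.copy_(tensor)  # head handled locally
                elif "lm_head" in name:
                    continue  # unused for classification
                else:
                    # Stream everything else to the base loader, one
                    # tensor at a time, so no second pass is needed.
                    BaseLM.load_weights(self, [(name, tensor)])
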
sglang/srt/models/llama_embedding.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Tuple
 
 import torch
 from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+from sglang.srt.models.llama import LlamaModel
 
 
 class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
         config: LlamaConfig,
         quant_config=None,
         cache_config=None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
         load_weights_per_param(name, loaded_weight)
 
 
-EntryClass = LlamaEmbeddingModel
-# compat: e5-mistral model.config class == MistralModel
-EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
+class MistralModel(LlamaEmbeddingModel):
+    pass
+
+
+EntryClass = [LlamaEmbeddingModel, MistralModel]
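
The EntryClassRemapping indirection is gone: an architecture alias is now just a subclass, and EntryClass may be a single class or a list of classes. A registry that consumes this convention could normalize it roughly like this (a sketch under that assumption, not the actual sglang model loader):

    # Hypothetical registry: accept EntryClass as one class or a list,
    # keyed by class name so "MistralModel" (e5-mistral checkpoints)
    # resolves to the shared Llama embedding implementation.
    model_registry = {}

    def register(entry_class):
        classes = entry_class if isinstance(entry_class, list) else [entry_class]
        for cls in classes:
            model_registry[cls.__name__] = cls
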
sglang/srt/models/llava.py CHANGED
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-        # load language model
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
 
     @property
     def num_patches_per_side(self):
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.vision_tower = None
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
+
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
 
         self.config = config
         self.vision_tower = None
+
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = Qwen2Config(self.config._name_or_path)
 
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
 
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
-
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 151646
 
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
 
         self.config = config
         self.vision_tower = None
+
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = MistralConfig(self.config._name_or_path)
 
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
 
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
-
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 32000
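
Dropping weights = list(weights) is the point of the load_weights hunk: weights is typically a lazy iterator over checkpoint tensors, and the old code materialized it only so it could be traversed twice (projector pass, then language-model pass). Routing each tensor exactly once keeps the load streaming. A tiny sketch of the failure the old two-pass code had to work around:

    # A generator is consumed once; a second pass over it sees nothing.
    def checkpoint_tensors():
        yield from [("a", 1), ("b", 2)]

    weights = checkpoint_tensors()
    first_pass = [name for name, _ in weights]   # ['a', 'b']
    second_pass = [name for name, _ in weights]  # [] -- already exhausted
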
 
sglang/srt/models/llavavid.py CHANGED
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 
 
 class LlavaVidForCausalLM(nn.Module):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
                     continue
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-        # load language model
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
 
     @property
     def num_patches_per_side(self):
sglang/srt/models/minicpm.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
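
This hunk is the release's recurring change, repeated below in mixtral, mixtral_quant, qwen, qwen2, qwen2_moe, and stablelm: sampling moves into the model, so forward returns a (sample_output, logits_output) pair instead of bare logits. A caller that used to consume logits directly would now unpack both, roughly like this (a sketch; the real call sites are in sglang's model runner, and batch/input_metadata stand in for its objects):

    sample_output, logits_output = model.forward(
        batch.input_ids, batch.positions, input_metadata
    )
    next_token_ids = sample_output.batch_next_token_ids  # sampled in-model
    next_token_logits = logits_output.next_token_logits  # still exposed
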
sglang/srt/models/mistral.py CHANGED
@@ -15,12 +15,11 @@ limitations under the License.
 
 """Inference-only Mistral model."""
 
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 
 
 class MistralForCausalLM(LlamaForCausalLM):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    pass
 
 
 EntryClass = MistralForCausalLM
sglang/srt/models/mixtral.py CHANGED
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -358,7 +362,7 @@ class MixtralForCausalLM(nn.Module):
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
                     )
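
The final hunk is a small bug fix: by the time the fused-MoE weight_loader is called, the loop has already rewritten the matched fragment weight_name into the full parameter name, so name is what the loader should receive. A simplified paraphrase of the surrounding loop (the exact mapping tuples are vllm internals and may differ):

    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)  # full param name now
            param = params_dict[name]
            # 0.3.0 passes `name`; the bare fragment `weight_name`
            # dropped the layer prefix the loader keys on.
            param.weight_loader(
                param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
            )
            break
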
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            return self.logits_processor(
+            logits_output = self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
 
sglang/srt/models/qwen2_moe.py CHANGED
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/stablelm.py CHANGED
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/openai_api/adapter.py CHANGED
@@ -844,8 +844,23 @@ def v1_chat_generate_request(
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
+                openai_compatible_messages = []
+                for message in request.messages:
+                    if isinstance(message.content, str):
+                        openai_compatible_messages.append(
+                            {"role": message.role, "content": message.content}
+                        )
+                    else:
+                        content_list = message.dict()["content"]
+                        for content in content_list:
+                            if content["type"] == "text":
+                                openai_compatible_messages.append(
+                                    {"role": message.role, "content": content["text"]}
+                                )
                 prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
-                    request.messages, tokenize=True, add_generation_prompt=True
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
                 )
                 stop = request.stop
                 image_data = None
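
The new loop exists because tokenizer.apply_chat_template expects each message's content to be a plain string, while the OpenAI schema also allows a list of typed content parts; text parts are flattened into plain messages and non-text parts are dropped. An illustrative before/after (made-up input, showing the transformation only):

    # List-style content in the request ...
    request_messages = [
        {"role": "user", "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ]},
    ]
    # ... becomes template-friendly plain-text messages:
    openai_compatible_messages = [
        {"role": "user", "content": "Describe this image."},
    ]
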
sglang/srt/openai_api/protocol.py CHANGED
@@ -200,11 +200,6 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = None
 
 
-class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant"]
-    content: str
-
-
 class ChatCompletionMessageContentTextPart(BaseModel):
     type: Literal["text"]
     text: str
@@ -225,6 +220,11 @@ ChatCompletionMessageContentPart = Union[
 ]
 
 
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: Union[str, List[ChatCompletionMessageContentTextPart]]
+
+
 class ChatCompletionMessageUserParam(BaseModel):
     role: Literal["user"]
     content: Union[str, List[ChatCompletionMessageContentPart]]
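
The class body is otherwise unchanged, but it moves below ChatCompletionMessageContentTextPart because the widened content annotation now references that model. With the new Union, system and assistant messages validate in both shapes (a sketch using the models above):

    # Both payloads now validate as ChatCompletionMessageGenericParam.
    plain = ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful assistant."
    )
    parts = ChatCompletionMessageGenericParam(
        role="system",
        content=[{"type": "text", "text": "You are a helpful assistant."}],
    )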