sglang 0.2.14.post2__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +12 -12
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -11
  11. sglang/srt/layers/extend_attention.py +13 -8
  12. sglang/srt/layers/logits_processor.py +4 -4
  13. sglang/srt/layers/sampler.py +69 -16
  14. sglang/srt/managers/controller_multi.py +5 -5
  15. sglang/srt/managers/controller_single.py +5 -5
  16. sglang/srt/managers/io_struct.py +6 -1
  17. sglang/srt/managers/schedule_batch.py +20 -8
  18. sglang/srt/managers/tokenizer_manager.py +2 -2
  19. sglang/srt/managers/tp_worker.py +38 -26
  20. sglang/srt/model_config.py +3 -3
  21. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  22. sglang/srt/model_executor/forward_batch_info.py +68 -23
  23. sglang/srt/model_executor/model_runner.py +14 -12
  24. sglang/srt/models/chatglm.py +4 -12
  25. sglang/srt/models/commandr.py +5 -1
  26. sglang/srt/models/dbrx.py +5 -1
  27. sglang/srt/models/deepseek.py +5 -1
  28. sglang/srt/models/deepseek_v2.py +57 -25
  29. sglang/srt/models/exaone.py +399 -0
  30. sglang/srt/models/gemma.py +5 -1
  31. sglang/srt/models/gemma2.py +5 -1
  32. sglang/srt/models/gpt_bigcode.py +5 -1
  33. sglang/srt/models/grok.py +5 -1
  34. sglang/srt/models/internlm2.py +5 -1
  35. sglang/srt/models/llama2.py +7 -3
  36. sglang/srt/models/llama_classification.py +2 -2
  37. sglang/srt/models/minicpm.py +5 -1
  38. sglang/srt/models/mixtral.py +6 -2
  39. sglang/srt/models/mixtral_quant.py +5 -1
  40. sglang/srt/models/qwen.py +5 -2
  41. sglang/srt/models/qwen2.py +6 -2
  42. sglang/srt/models/qwen2_moe.py +5 -14
  43. sglang/srt/models/stablelm.py +5 -1
  44. sglang/srt/openai_api/adapter.py +16 -1
  45. sglang/srt/openai_api/protocol.py +5 -5
  46. sglang/srt/sampling/sampling_batch_info.py +79 -6
  47. sglang/srt/server.py +6 -6
  48. sglang/srt/utils.py +0 -3
  49. sglang/test/runners.py +1 -1
  50. sglang/version.py +1 -1
  51. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/METADATA +7 -7
  52. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/RECORD +55 -52
  53. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  54. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  55. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0

sglang/srt/layers/sampler.py
@@ -1,4 +1,6 @@
+import dataclasses
 import logging
+from typing import Union

 import torch
 from flashinfer.sampling import (
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)


+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()

-    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

-        logits = sampling_info.penalizer_orchestrator.apply(logits)
+        logits = self._apply_penalties(logits, sampling_info)

-        probs = torch.softmax(logits, dim=-1)
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)

         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.min_ps.any():
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )

-        if not torch.all(success):
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        return SampleOutput(success, probs, batch_next_token_ids)

-        return batch_next_token_ids
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )

-    def forward_native():
-        raise NotImplementedError("Native forward is not implemented yet.")
+        return SampleOutput(success, probs, batch_next_token_ids)


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
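
The new forward_native path relies on a torch-only sampler that masks the sorted probabilities with top-p, top-k, and min-p before drawing a token. Below is a minimal standalone sketch of that filtering order; the function name, the plain num_samples=1 draw, the sum-based renormalization, and the toy shapes are simplifications, not the packaged top_k_top_p_min_p_sampling_from_probs_torch code.

    import torch

    def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
        """Filter a [batch, vocab] probability tensor and draw one token id per row."""
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # top-p: drop a token once the mass of strictly higher-ranked tokens exceeds top_p
        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
        # top-k: keep only the k highest-ranked tokens per row
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
        probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
        # min-p: drop tokens whose probability is below min_p times the row maximum
        probs_sort[probs_sort < (probs_sort[:, :1] * min_ps.view(-1, 1))] = 0.0
        # renormalize, sample, and map sorted positions back to vocabulary ids
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

    probs = torch.softmax(torch.randn(2, 32), dim=-1)
    ids = sample_top_k_top_p_min_p(
        probs,
        top_ks=torch.tensor([5, 50]),
        top_ps=torch.tensor([0.9, 1.0]),
        min_ps=torch.tensor([0.0, 0.1]),
    )
    print(ids.shape)  # torch.Size([2]), one sampled token id per row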

sglang/srt/managers/controller_multi.py
@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args,
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
                 self.server_args,
                 self.port_args,
                 pipe_controller_writer,
-                self.model_overide_args,
+                self.model_override_args,
                 True,
                 gpu_ids,
                 dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Start a controller process."""

     configure_logger(server_args)

     try:
-        controller = ControllerMulti(server_args, port_args, model_overide_args)
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

sglang/srt/managers/controller_single.py
@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict,
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_overide_args,
+                model_override_args,
             )

         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_overide_args,
+            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_overide_args: dict,
+    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +149,7 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_overide_args,
+            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,

sglang/srt/managers/io_struct.py
@@ -18,8 +18,9 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

+import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -249,6 +250,10 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

+    def __post_init__(self):
+        # deepcopy meta_info to avoid modification in place
+        self.meta_info = copy.deepcopy(self.meta_info)
+

 @dataclass
 class BatchStrOut:
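
The added __post_init__ guards against aliasing: the comment in the diff says the deepcopy exists to avoid in-place modification of meta_info after the message is built. A toy reproduction of that hazard and the fix, using a hypothetical Out dataclass rather than the real BatchTokenIDOut:

    import copy
    import dataclasses
    from typing import Dict, List

    @dataclasses.dataclass
    class Out:  # hypothetical stand-in for BatchTokenIDOut
        meta_info: List[Dict]

        def __post_init__(self):
            # Snapshot the dicts so later producer-side edits do not leak into this message.
            self.meta_info = copy.deepcopy(self.meta_info)

    shared = [{"prompt_tokens": 12}]
    msg = Out(shared)
    shared[0]["prompt_tokens"] = 99                  # producer keeps mutating its own dict
    assert msg.meta_info[0]["prompt_tokens"] == 12   # the queued message is unaffected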

sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@ limitations under the License.

 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch

@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -678,11 +684,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

-    def sample(self, logits: torch.Tensor):
-        from sglang.srt.layers.sampler import Sampler
-
-        sampler = Sampler()
-
-        batch_next_token_ids = sampler(logits, self.sampling_info)
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids

-        return batch_next_token_ids
+        return sample_output.batch_next_token_ids
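
check_sample_results centralizes the fallback that previously lived inside Sampler.forward_cuda: rows whose sampling kernel failed are replaced by an argmax over NaN-cleaned probabilities. A small illustration of that recovery step on a fake SampleOutput (stand-in dataclass and made-up values, not the real import):

    import dataclasses
    import torch

    @dataclasses.dataclass
    class FakeSampleOutput:  # stand-in for sglang.srt.layers.sampler.SampleOutput
        success: torch.Tensor
        probs: torch.Tensor
        batch_next_token_ids: torch.Tensor

    out = FakeSampleOutput(
        success=torch.tensor([True, False]),
        probs=torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]]),
        batch_next_token_ids=torch.tensor([1, 7]),  # 7 is garbage from the failed row
    )

    # Same recovery as check_sample_results: failed rows fall back to argmax over cleaned probs.
    if not torch.all(out.success):
        probs = out.probs.masked_fill(torch.isnan(out.probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        out.batch_next_token_ids = torch.where(out.success, out.batch_next_token_ids, argmax_ids)

    print(out.batch_next_token_ids)  # tensor([1, 0])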

sglang/srt/managers/tokenizer_manager.py
@@ -77,7 +77,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict = None,
+        model_override_args: dict = None,
     ):
         self.server_args = server_args

@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding

sglang/srt/managers/tp_worker.py
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -76,7 +76,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_overide_args: dict,
+        model_override_args: dict,
     ):
         suppress_other_loggers()

@@ -93,7 +93,7 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )

         self.model_runner = ModelRunner(
@@ -504,21 +504,29 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )

                 # Move logprobs to cpu
-                if output.next_token_logprobs is not None:
-                    output.next_token_logprobs = output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=next_token_ids.device),
-                        next_token_ids,
-                    ].tolist()
-                    output.input_token_logprobs = output.input_token_logprobs.tolist()
-                    output.normalized_prompt_logprobs = (
-                        output.normalized_prompt_logprobs.tolist()
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )

                 next_token_ids = next_token_ids.tolist()
@@ -557,12 +565,14 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)

                 if req.return_logprob:
-                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = output.embeddings.tolist()
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -590,7 +600,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitProcessorOutput,
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -672,15 +682,17 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )

         # Move logprobs to cpu
-        if output.next_token_logprobs is not None:
-            next_token_logprobs = output.next_token_logprobs[
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -706,7 +718,7 @@ class ModelTpServer:
                         (next_token_logprobs[i], next_token_id)
                     )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(output.output_top_logprobs[i])
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -864,7 +876,7 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -875,7 +887,7 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_overide_args,
+        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group

@@ -892,14 +904,14 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
         )
         proc.start()
         procs.append(proc)
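
Both the extend and decode paths in tp_worker move logprobs to CPU with the same gather: index the [batch, vocab] logprob matrix with (arange(batch), sampled_ids) to pull one value per request before calling .tolist(). The indexing trick in isolation, with toy shapes chosen here for illustration:

    import torch

    # Toy shapes: 3 requests, vocabulary of 5.
    next_token_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
    next_token_ids = torch.tensor([4, 0, 2])

    # Same advanced indexing as in tp_worker: one logprob per request, at its sampled token id.
    picked = next_token_logprobs[
        torch.arange(len(next_token_ids), device=next_token_ids.device),
        next_token_ids,
    ].tolist()
    print(picked)  # three Python floats, ready to attach to per-request metadata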

sglang/srt/model_config.py
@@ -33,17 +33,17 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_overide_args: Optional[dict] = None,
+        model_override_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.hf_config = get_config(
             self.path,
             trust_remote_code,
             revision,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:

sglang/srt/model_executor/cuda_graph_runner.py
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather


@@ -144,6 +146,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]

+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

         if use_torch_compile:
@@ -235,6 +241,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -299,27 +306,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        output = self.output_buffers[bs]
+        sample_output, logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )

         # Extract logprobs
         if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -327,8 +342,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]

-        return output
+        return sample_output, logits_output
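
CudaGraphRunner now replays the captured graph, returns the (sample_output, logits_output) pair, and applies the usual padded-graph discipline: capture once over fixed-size buffers (including the dummy SamplingBatchInfo that inplace_assign overwrites before each replay), then slice every output back down to the real batch size. A generic sketch of that capture/replay/unpad pattern, using a plain nn.Linear instead of an sglang model and requiring a CUDA device; everything here is illustrative, not the runner's code.

    import torch

    assert torch.cuda.is_available()
    max_bs, raw_bs = 8, 3
    model = torch.nn.Linear(16, 32).cuda().eval()
    static_input = torch.zeros(max_bs, 16, device="cuda")

    with torch.no_grad():
        # Warm up on a side stream so lazy cuBLAS initialization happens outside capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # Capture one graph over the fixed-size input buffer.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = model(static_input)

    # Replay: copy the real (smaller) batch into the leading rows, then unpad the output,
    # the same [:raw_bs] slicing CudaGraphRunner applies to logits_output and sample_output.
    static_input[:raw_bs].copy_(torch.randn(raw_bs, 16, device="cuda"))
    graph.replay()
    result = static_output[:raw_bs]
    print(result.shape)  # torch.Size([3, 32])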