sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/backend/runtime_endpoint.py +8 -4
  4. sglang/lang/interpreter.py +3 -0
  5. sglang/lang/ir.py +5 -0
  6. sglang/launch_server_llavavid.py +12 -12
  7. sglang/srt/configs/__init__.py +5 -0
  8. sglang/srt/configs/exaone.py +195 -0
  9. sglang/srt/constrained/fsm_cache.py +1 -1
  10. sglang/srt/conversation.py +24 -2
  11. sglang/srt/hf_transformers_utils.py +12 -12
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/sampler.py +94 -17
  15. sglang/srt/managers/controller_multi.py +5 -5
  16. sglang/srt/managers/controller_single.py +5 -5
  17. sglang/srt/managers/io_struct.py +6 -1
  18. sglang/srt/managers/schedule_batch.py +26 -11
  19. sglang/srt/managers/tokenizer_manager.py +9 -9
  20. sglang/srt/managers/tp_worker.py +38 -26
  21. sglang/srt/model_config.py +3 -3
  22. sglang/srt/model_executor/cuda_graph_runner.py +26 -9
  23. sglang/srt/model_executor/forward_batch_info.py +68 -23
  24. sglang/srt/model_executor/model_runner.py +15 -22
  25. sglang/srt/models/chatglm.py +9 -15
  26. sglang/srt/models/commandr.py +5 -1
  27. sglang/srt/models/dbrx.py +5 -1
  28. sglang/srt/models/deepseek.py +5 -1
  29. sglang/srt/models/deepseek_v2.py +57 -25
  30. sglang/srt/models/exaone.py +368 -0
  31. sglang/srt/models/gemma.py +5 -1
  32. sglang/srt/models/gemma2.py +5 -1
  33. sglang/srt/models/gpt_bigcode.py +5 -1
  34. sglang/srt/models/grok.py +5 -1
  35. sglang/srt/models/internlm2.py +5 -1
  36. sglang/srt/models/{llama2.py → llama.py} +25 -45
  37. sglang/srt/models/llama_classification.py +34 -41
  38. sglang/srt/models/llama_embedding.py +7 -6
  39. sglang/srt/models/llava.py +8 -11
  40. sglang/srt/models/llavavid.py +5 -6
  41. sglang/srt/models/minicpm.py +5 -1
  42. sglang/srt/models/mistral.py +2 -3
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +6 -2
  47. sglang/srt/models/qwen2_moe.py +5 -14
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/openai_api/adapter.py +16 -1
  50. sglang/srt/openai_api/protocol.py +5 -5
  51. sglang/srt/sampling/sampling_batch_info.py +75 -6
  52. sglang/srt/server.py +6 -6
  53. sglang/srt/utils.py +0 -3
  54. sglang/test/runners.py +1 -1
  55. sglang/test/test_programs.py +68 -0
  56. sglang/test/test_utils.py +4 -0
  57. sglang/utils.py +39 -0
  58. sglang/version.py +1 -1
  59. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
  60. sglang-0.3.0.dist-info/RECORD +118 -0
  61. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
  62. sglang-0.2.14.post2.dist-info/RECORD +0 -115
  63. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
  64. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 
 
 @dataclasses.dataclass
-class LogitProcessorOutput:
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             else:
                 output_top_logprobs = None
 
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
             # Remove the last token logprob for the prefill tokens.
             input_token_logprobs = input_token_logprobs[:-1]
 
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
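The only change to this file is the rename of the output dataclass (LogitProcessorOutput → LogitsProcessorOutput); its fields are untouched. Any downstream code importing the old name must be updated, for example:

    # sglang 0.2.14.post2
    from sglang.srt.layers.logits_processor import LogitProcessorOutput

    # sglang 0.3.0 (same fields, corrected spelling)
    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
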
sglang/srt/layers/sampler.py

@@ -1,4 +1,6 @@
+import dataclasses
 import logging
+from typing import Tuple, Union
 
 import torch
 from flashinfer.sampling import (
@@ -7,8 +9,11 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,37 +21,76 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
+        # FIXME: torch.multinomial has too many bugs
+        self.forward_native = self.forward_cuda
+        self.is_torch_compile = False
+
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
 
-    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if self.is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
 
-        logits = sampling_info.penalizer_orchestrator.apply(logits)
+        logits = self._apply_penalties(logits, sampling_info)
 
-        probs = torch.softmax(logits, dim=-1)
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.min_ps.any():
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                batch_next_token_ids, success = flashinfer_top_k_top_p(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
         else:
@@ -55,18 +99,48 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-        if not torch.all(success):
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
 
-        return batch_next_token_ids
 
-    def forward_native():
-        raise NotImplementedError("Native forward is not implemented yet.")
+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +161,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
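The most notable new mechanism above is wrapping the FlashInfer top-k/top-p kernel in torch.library.custom_op with a register_fake implementation, so torch.compile can trace the sampler without executing the CUDA kernel. A minimal, self-contained sketch of the same pattern on a toy op (requires PyTorch 2.4+; the demo:: namespace and the clamp operation are illustrative, not part of sglang or FlashInfer):

    import torch
    from torch.library import custom_op


    @custom_op("demo::clamp_probs", mutates_args=())
    def clamp_probs(probs: torch.Tensor, eps: float) -> torch.Tensor:
        # Real implementation: runs eagerly on the actual tensors.
        return probs.clamp_min(eps)


    @clamp_probs.register_fake
    def _(probs: torch.Tensor, eps: float) -> torch.Tensor:
        # Fake (meta) implementation: only describes the output shape/dtype,
        # which is all torch.compile needs to trace through the op.
        return torch.empty_like(probs)


    compiled = torch.compile(lambda p: clamp_probs(p, 1e-6))
    print(compiled(torch.rand(2, 8)))

The register_fake stub in the diff follows the same idea: it returns tensors of the right shape and dtype for tracing, while the real FlashInfer kernel still runs at execution time.
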
sglang/srt/managers/controller_multi.py

@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args,
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.model_overide_args,
+            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Start a controller process."""
 
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args, model_overide_args)
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
sglang/srt/managers/controller_single.py

@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict,
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_overide_args,
+                model_override_args,
             )
 
         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_overide_args,
+            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer: multiprocessing.connection.Connection,
-    model_overide_args: dict,
+    model_override_args: dict,
    is_data_parallel_worker: bool = False,
    gpu_ids: List[int] = None,
    dp_worker_id: int = None,
@@ -149,7 +149,7 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_overide_args,
+            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
sglang/srt/managers/io_struct.py

@@ -18,8 +18,9 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -249,6 +250,10 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
+    def __post_init__(self):
+        # deepcopy meta_info to avoid modification in place
+        self.meta_info = copy.deepcopy(self.meta_info)
+
 
 @dataclass
 class BatchStrOut:
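BatchTokenIDOut now deep-copies meta_info in __post_init__, so later in-place edits (for example by the detokenizer) no longer mutate the sender's objects. A standalone sketch of the pattern (the Out class below is illustrative):

    import copy
    from dataclasses import dataclass, field
    from typing import Dict, List


    @dataclass
    class Out:
        meta_info: List[Dict] = field(default_factory=list)

        def __post_init__(self):
            # Detach from the caller's list so in-place edits stay local.
            self.meta_info = copy.deepcopy(self.meta_info)


    shared = [{"finish_reason": None}]
    out = Out(shared)
    out.meta_info[0]["finish_reason"] = "stop"
    print(shared[0]["finish_reason"])  # None: the caller's dict is untouched
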
sglang/srt/managers/schedule_batch.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@ limitations under the License.
 
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 
@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -172,19 +178,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-        max_prefix_len = input_len
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1
 
         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
@@ -678,11 +687,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-    def sample(self, logits: torch.Tensor):
-        from sglang.srt.layers.sampler import Sampler
-
-        sampler = Sampler()
-
-        batch_next_token_ids = sampler(logits, self.sampling_info)
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-        return batch_next_token_ids
+        return sample_output.batch_next_token_ids
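The NaN/failure fallback that used to live inside the sampler now runs in ScheduleBatch.check_sample_results: rows whose sampling failed fall back to greedy argmax over the NaN-cleaned probabilities. A self-contained sketch of that fallback on dummy tensors:

    import torch

    probs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]])
    batch_next_token_ids = torch.tensor([1, 0])
    success = torch.tensor([True, False])

    if not torch.all(success):
        probs = probs.masked_fill(torch.isnan(probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        # Keep sampled ids where sampling succeeded, fall back to argmax elsewhere.
        batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

    print(batch_next_token_ids)  # tensor([1, 0])
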
sglang/srt/managers/tokenizer_manager.py

@@ -77,7 +77,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict = None,
+        model_override_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -86,8 +86,8 @@ class TokenizerManager:
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
+        self.send_to_controller = context.socket(zmq.PUSH)
+        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
 
         # Read model args
         self.model_path = server_args.model_path
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
@@ -271,7 +271,7 @@ class TokenizerManager:
             input_ids,
             sampling_params,
         )
-        self.send_to_router.send_pyobj(tokenized_obj)
+        self.send_to_controller.send_pyobj(tokenized_obj)
 
         # Recv results
         event = asyncio.Event()
@@ -367,7 +367,7 @@ class TokenizerManager:
             input_ids,
             sampling_params,
         )
-        self.send_to_router.send_pyobj(tokenized_obj)
+        self.send_to_controller.send_pyobj(tokenized_obj)
 
         event = asyncio.Event()
         state = ReqState([], False, event)
@@ -500,14 +500,14 @@ class TokenizerManager:
 
     def flush_cache(self):
         req = FlushCacheReq()
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
 
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +524,7 @@ class TokenizerManager:
             # wait for the previous generation requests to finish
             while len(self.rid_to_state) > 0:
                 await asyncio.sleep(0)
-            self.send_to_router.send_pyobj(obj)
+            self.send_to_controller.send_pyobj(obj)
             self.model_update_result = asyncio.Future()
             result = await self.model_update_result
             if result.success:
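Only the socket name changes here (send_to_router → send_to_controller); the transport is still a ZeroMQ PUSH socket shipping pickled request objects to the controller's PULL socket. A minimal sketch of that pattern (the port and payload below are illustrative):

    import zmq

    context = zmq.Context()

    # Controller side: bind a PULL socket.
    recv_from_tokenizer = context.socket(zmq.PULL)
    recv_from_tokenizer.bind("tcp://127.0.0.1:5557")

    # Tokenizer side: connect a PUSH socket and send a Python object.
    send_to_controller = context.socket(zmq.PUSH)
    send_to_controller.connect("tcp://127.0.0.1:5557")
    send_to_controller.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})

    print(recv_from_tokenizer.recv_pyobj())  # {'rid': 'req-0', 'input_ids': [1, 2, 3]}
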
sglang/srt/managers/tp_worker.py

@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -76,7 +76,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_overide_args: dict,
+        model_override_args: dict,
     ):
         suppress_other_loggers()
 
@@ -93,7 +93,7 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
 
         self.model_runner = ModelRunner(
@@ -504,21 +504,29 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
 
                 # Move logprobs to cpu
-                if output.next_token_logprobs is not None:
-                    output.next_token_logprobs = output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=next_token_ids.device),
-                        next_token_ids,
-                    ].tolist()
-                    output.input_token_logprobs = output.input_token_logprobs.tolist()
-                    output.normalized_prompt_logprobs = (
-                        output.normalized_prompt_logprobs.tolist()
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )
 
                 next_token_ids = next_token_ids.tolist()
@@ -557,12 +565,14 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = output.embeddings.tolist()
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -590,7 +600,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitProcessorOutput,
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -672,15 +682,17 @@ class ModelTpServer:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-        output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
 
         # Move logprobs to cpu
-        if output.next_token_logprobs is not None:
-            next_token_logprobs = output.next_token_logprobs[
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -706,7 +718,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
             if req.top_logprobs_num > 0:
                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -864,7 +876,7 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -875,7 +887,7 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_overide_args,
+        model_override_args,
    )
    tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -892,14 +904,14 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
         )
         proc.start()
         procs.append(proc)
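In both the extend and decode paths the worker now unpacks a (sample_output, logits_output) pair from ModelRunner.forward and gathers each sequence's sampled-token logprob by advanced indexing with torch.arange before moving it to the CPU. A standalone sketch of just that indexing step on dummy tensors:

    import torch

    # Dummy batch: 3 sequences over a vocabulary of 5 tokens.
    next_token_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
    next_token_ids = torch.tensor([4, 0, 2])

    # Pick the logprob of the token actually sampled for each sequence,
    # then move the result to the CPU as a plain Python list.
    picked = next_token_logprobs[
        torch.arange(len(next_token_ids), device=next_token_ids.device),
        next_token_ids,
    ].tolist()

    print(picked)  # three floats, one per sequence
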
sglang/srt/model_config.py

@@ -33,17 +33,17 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_overide_args: Optional[dict] = None,
+        model_override_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.hf_config = get_config(
             self.path,
             trust_remote_code,
             revision,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
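For callers this is a keyword rename, so code that constructed ModelConfig with the old spelling needs a one-word update. Schematically (the model path and override dict below are illustrative, and building a ModelConfig fetches the Hugging Face config, so the model files must be reachable):

    from sglang.srt.model_config import ModelConfig

    # 0.2.14.post2 spelling (no longer accepted):
    #   ModelConfig("meta-llama/Meta-Llama-3-8B", model_overide_args={"num_hidden_layers": 8})

    # 0.3.0 spelling:
    config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B",
        model_override_args={"num_hidden_layers": 8},
    )
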