sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
-class LogitProcessorOutput:
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -55,6 +55,9 @@ class LogitsMetadata:
     extend_start_loc: Optional[torch.Tensor] = None
     top_logprobs_nums: Optional[List[int]] = None
 
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
+
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
         return cls(
@@ -63,22 +66,30 @@ class LogitsProcessor(nn.Module):
             extend_start_loc=input_metadata.extend_start_loc,
             return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
+            logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
         )
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )
 
     def _get_normalized_prompt_logprobs(
-        self, input_token_logprobs, logits_metadata: LogitsMetadata
+        self,
+        input_token_logprobs: torch.Tensor,
+        cum_start_len0: torch.Tensor,
+        cum_start_len1: torch.Tensor,
+        logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
 
-        start = logits_metadata.extend_start_loc.clone()
-        end = start + logits_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone() - cum_start_len0
+        end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -91,7 +102,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     @staticmethod
-    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
@@ -105,7 +116,7 @@ class LogitsProcessor(nn.Module):
             # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
 
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -113,26 +124,30 @@ class LogitsProcessor(nn.Module):
             indices = ret.indices.tolist()
 
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                pruned_len = extend_seq_len - start_len
+
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
+
                 k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(extend_seq_len - 1)
+                        for j in range(pruned_len - 1)
                     ]
                 )
                 output_top_logprobs.append(
                     list(
                         zip(
-                            values[pt + extend_seq_len - 1][:k],
-                            indices[pt + extend_seq_len - 1][:k],
+                            values[pt + pruned_len - 1][:k],
+                            indices[pt + pruned_len - 1][:k],
                         )
                     )
                 )
-                pt += extend_seq_len
+                pt += pruned_len
 
             return input_top_logprobs, output_top_logprobs
 
@@ -159,18 +174,18 @@ class LogitsProcessor(nn.Module):
         last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
            last_logits = tensor_model_parallel_all_gather(last_logits)
        last_logits = last_logits[:, : self.config.vocab_size].float()
 
        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits /= self.config.final_logit_softcapping
-            last_logits = torch.tanh(last_logits)
-            last_logits *= self.config.final_logit_softcapping
+            last_logits.div_(self.config.final_logit_softcapping)
+            torch.tanh(last_logits, out=last_logits)
+            last_logits.mul_(self.config.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -194,7 +209,7 @@ class LogitsProcessor(nn.Module):
             else:
                 output_top_logprobs = None
 
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
@@ -203,15 +218,31 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )
         else:
-            all_logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
+            pt, states, pruned_input_ids = 0, [], []
+            for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            states = torch.cat(states, dim=0)
+            pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
+
+            cum_start_len1 = torch.tensor(
+                logits_metadata.logprob_start_lens_cpu, device="cuda"
+            ).cumsum(0)
+            cum_start_len0 = torch.zeros_like(cum_start_len1)
+            cum_start_len0[1:] = cum_start_len1[:-1]
+
+            all_logits = torch.matmul(states, weight.T)
+            if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
             if hasattr(self.config, "final_logit_softcapping"):
-                all_logits /= self.config.final_logit_softcapping
-                all_logits = torch.tanh(all_logits)
-                all_logits *= self.config.final_logit_softcapping
+                all_logits.div_(self.config.final_logit_softcapping)
+                torch.tanh(all_logits, out=all_logits)
+                all_logits.mul_(self.config.final_logit_softcapping)
 
             all_logprobs = all_logits
             del all_logits, hidden_states
@@ -228,20 +259,26 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs = output_top_logprobs = None
 
-            last_logprobs = all_logprobs[last_index]
+            last_logprobs = all_logprobs[last_index - cum_start_len1]
 
             # Compute the logprobs and normalized logprobs for the prefill tokens.
             # Note that we pad a zero at the end of each sequence for easy computation.
             input_token_logprobs = all_logprobs[
                 torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
             ]
 
             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs, logits_metadata
+                input_token_logprobs,
+                cum_start_len0,
+                cum_start_len1,
+                logits_metadata,
            )
 
-            return LogitProcessorOutput(
+            # Remove the last token logprob for the prefill tokens.
+            input_token_logprobs = input_token_logprobs[:-1]
+
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
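
Note: the final_logit_softcapping blocks in both branches still compute the same transform, logits = cap * tanh(logits / cap); the diff only rewrites it with in-place ops. A minimal standalone sketch of that transform (illustrative only, not part of the diff):

import torch

def soft_cap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Equivalent to cap * tanh(logits / cap), written with the in-place ops used above.
    out = logits.clone()  # keep the caller's tensor untouched in this sketch
    out.div_(cap)
    torch.tanh(out, out=out)
    out.mul_(cap)
    return out

# Large logits are squashed smoothly into (-cap, cap):
print(soft_cap_logits(torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0]), cap=30.0))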
sglang/srt/layers/radix_attention.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,6 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -46,6 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
@@ -113,14 +117,25 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+            if k is not None:
+                assert v is not None
+                self.store_kv_cache(k, v, input_metadata)
 
-            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
@@ -138,14 +153,12 @@ class RadixAttention(nn.Module):
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                        causal=False,
-                        sm_scale=self.scaling,
-                        logits_soft_cap=self.logit_cap,
-                    )
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
                 )
 
                 o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +171,18 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
 
-        o = input_metadata.flashinfer_decode_wrapper.forward(
+        if k is not None:
+            assert v is not None
+            self.store_kv_cache(k, v, input_metadata)
+
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
@@ -170,8 +192,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
@@ -179,7 +203,6 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        k_cache[input_metadata.out_cache_loc] = cache_k
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
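
The new sliding_window_size argument is forwarded to flashinfer as window_left, with -1 meaning full causal attention. As a rough illustration of the restriction a sliding window imposes (a hypothetical helper, not sglang or flashinfer code), each query position only attends to itself and the previous window_left positions:

import torch

def sliding_window_causal_mask(seq_len: int, window_left: int) -> torch.Tensor:
    # True = attention allowed. window_left < 0 degenerates to a plain causal mask.
    i = torch.arange(seq_len).view(-1, 1)  # query positions
    j = torch.arange(seq_len).view(1, -1)  # key positions
    causal = j <= i
    if window_left < 0:
        return causal
    return causal & (j >= i - window_left)

print(sliding_window_causal_mask(seq_len=6, window_left=2).int())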
sglang/srt/layers/sampler.py (new file)
@@ -0,0 +1,154 @@
+import dataclasses
+import logging
+from typing import Union
+
+import torch
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
+# TODO: move this dict to another place
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
+class Sampler(CustomOp):
+    def __init__(self):
+        super().__init__()
+
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+
+        logits = self._apply_penalties(logits, sampling_info)
+
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
+
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            if sampling_info.need_min_p_sampling:
+                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+            )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
+):
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
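
The torch-native fallback at the bottom of the new file can be exercised on its own. A hypothetical usage sketch (the batch values are made up, and the import assumes sglang plus its flashinfer/vllm dependencies are installed):

import torch

from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

batch_size, vocab_size = 4, 32000
probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)

# One top-k / top-p / min-p setting per request in the batch.
top_ks = torch.tensor([40, 40, 1, 100])
top_ps = torch.tensor([0.95, 0.90, 1.00, 0.80])
min_ps = torch.tensor([0.00, 0.05, 0.00, 0.00])

next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)
print(next_token_ids, success)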
sglang/srt/managers/controller_multi.py
@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
 import dataclasses
 import logging
 import multiprocessing
-import os
 from enum import Enum, auto
 
 import numpy as np
@@ -36,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -194,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""
 
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
@@ -212,6 +208,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
sglang/srt/managers/controller_single.py
@@ -17,7 +17,6 @@ limitations under the License.
 
 import logging
 import multiprocessing
-import os
 from typing import List
 
 import zmq
@@ -28,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -53,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue
 
-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)
 
         if not self.is_dp_worker:
@@ -134,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)
 
     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
@@ -167,6 +166,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
sglang/srt/managers/detokenizer_manager.py
@@ -17,7 +17,6 @@ limitations under the License.
 
 import asyncio
 import dataclasses
-import inspect
 from typing import List
 
 import uvloop
@@ -29,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 @dataclasses.dataclass
 class DecodeStatus:
+    """Store the status of incremental decoding."""
+
     vid: int
     decoded_text: str
     decode_ids: List[int]
@@ -47,11 +49,14 @@ class DecodeStatus:
 
 
 class DetokenizerManager:
+    """DetokenizerManager is a process that detokenizes the token ids."""
+
     def __init__(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
@@ -71,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}
 
     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            recv_obj = await self.recv_from_router.recv_pyobj()
 
             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,
@@ -84,15 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
 
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -126,8 +137,7 @@ class DetokenizerManager:
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
 
-            # Trim stop str
-            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
@@ -144,6 +154,7 @@ class DetokenizerManager:
 
                 output_strs.append(s.decoded_text + new_text)
 
+                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                     if pos != -1:
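
For context, the "Incremental decoding" block these hunks touch emits only the newly stabilized text at each step. A simplified sketch of that idea (illustrative only; the field names mirror DecodeStatus, but the real DetokenizerManager additionally handles batching, special tokens, and stop-string trimming, and the gpt2 tokenizer here is a hypothetical choice):

from transformers import AutoTokenizer

class IncrementalDecoder:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.decoded_text = ""   # text already committed to the output
        self.decode_ids = []     # all generated token ids so far
        self.surr_offset = 0     # ids before this are fully decoded
        self.read_offset = 0     # ids before this are reflected in decoded_text

    def add_token(self, token_id: int) -> str:
        self.decode_ids.append(token_id)
        surr_text = self.tokenizer.decode(self.decode_ids[self.surr_offset : self.read_offset])
        read_text = self.tokenizer.decode(self.decode_ids[self.surr_offset :])
        new_text = read_text[len(surr_text) :]
        if new_text and not new_text.endswith("\ufffd"):
            # Commit only once the suffix no longer ends in a partial character.
            self.decoded_text += new_text
            self.surr_offset = self.read_offset
            self.read_offset = len(self.decode_ids)
            return new_text
        return ""

tok = AutoTokenizer.from_pretrained("gpt2")
dec = IncrementalDecoder(tok)
for tid in tok.encode("Hello, incremental world!"):
    print(dec.add_token(tid), end="")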