sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
@@ -55,6 +55,9 @@ class LogitsMetadata:
     extend_start_loc: Optional[torch.Tensor] = None
     top_logprobs_nums: Optional[List[int]] = None

+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
+
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
         return cls(
@@ -63,22 +66,30 @@ class LogitsMetadata:
             extend_start_loc=input_metadata.extend_start_loc,
             return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
+            logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
         )


 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )

     def _get_normalized_prompt_logprobs(
-        self, input_token_logprobs, logits_metadata: LogitsMetadata
+        self,
+        input_token_logprobs: torch.Tensor,
+        cum_start_len0: torch.Tensor,
+        cum_start_len1: torch.Tensor,
+        logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

-        start = logits_metadata.extend_start_loc.clone()
-        end = start + logits_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone() - cum_start_len0
+        end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -91,7 +102,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs

     @staticmethod
-    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
@@ -105,7 +116,7 @@ class LogitsProcessor(nn.Module):
             # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu

             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -113,26 +124,30 @@ class LogitsProcessor(nn.Module):
             indices = ret.indices.tolist()

             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                pruned_len = extend_seq_len - start_len
+
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
+
                 k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(extend_seq_len - 1)
+                        for j in range(pruned_len - 1)
                     ]
                 )
                 output_top_logprobs.append(
                     list(
                         zip(
-                            values[pt + extend_seq_len - 1][:k],
-                            indices[pt + extend_seq_len - 1][:k],
+                            values[pt + pruned_len - 1][:k],
+                            indices[pt + pruned_len - 1][:k],
                         )
                     )
                 )
-                pt += extend_seq_len
+                pt += pruned_len

             return input_top_logprobs, output_top_logprobs

@@ -159,13 +174,13 @@ class LogitsProcessor(nn.Module):
         last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits.div_(self.config.final_logit_softcapping)
-            last_logits = torch.tanh(last_logits)
+            torch.tanh(last_logits, out=last_logits)
             last_logits.mul_(self.config.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
@@ -203,14 +218,30 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )
         else:
-            all_logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
+            pt, states, pruned_input_ids = 0, [], []
+            for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            states = torch.cat(states, dim=0)
+            pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
+
+            cum_start_len1 = torch.tensor(
+                logits_metadata.logprob_start_lens_cpu, device="cuda"
+            ).cumsum(0)
+            cum_start_len0 = torch.zeros_like(cum_start_len1)
+            cum_start_len0[1:] = cum_start_len1[:-1]
+
+            all_logits = torch.matmul(states, weight.T)
+            if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()

             if hasattr(self.config, "final_logit_softcapping"):
                 all_logits.div_(self.config.final_logit_softcapping)
-                all_logits = torch.tanh(all_logits)
+                torch.tanh(all_logits, out=all_logits)
                 all_logits.mul_(self.config.final_logit_softcapping)

             all_logprobs = all_logits
@@ -228,19 +259,25 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs = output_top_logprobs = None

-            last_logprobs = all_logprobs[last_index]
+            last_logprobs = all_logprobs[last_index - cum_start_len1]

             # Compute the logprobs and normalized logprobs for the prefill tokens.
             # Note that we pad a zero at the end of each sequence for easy computation.
             input_token_logprobs = all_logprobs[
                 torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
             ]

             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs, logits_metadata
+                input_token_logprobs,
+                cum_start_len0,
+                cum_start_len1,
+                logits_metadata,
             )

+            # Remove the last token logprob for the prefill tokens.
+            input_token_logprobs = input_token_logprobs[:-1]
+
             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
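The prefill branch above now slices hidden_states and input_ids so that prompt logprobs are only materialized from each request's logprob_start_len onward. A minimal CPU-only sketch of that pruning loop, using made-up lengths rather than values from the package:

import torch

# Two prefill requests of lengths 3 and 4 (illustrative numbers), hidden size 5.
extend_seq_lens_cpu = [3, 4]
logprob_start_lens_cpu = [1, 2]
hidden_states = torch.randn(7, 5)
input_ids = torch.arange(7)

pt, states, pruned_input_ids = 0, [], []
for i, extend_len in enumerate(extend_seq_lens_cpu):
    start_len = logprob_start_lens_cpu[i]
    # Keep only the positions at or after this request's logprob start.
    states.append(hidden_states[pt + start_len : pt + extend_len])
    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
    pt += extend_len

states = torch.cat(states, dim=0)                  # shape (2 + 2, 5)
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
print(states.shape, pruned_input_ids.tolist())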
@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        k_cache[input_metadata.out_cache_loc] = cache_k
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
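The RadixAttention change above moves the KV-cache write behind the memory pool instead of indexing the key/value buffers directly. A toy pool showing what a set_kv_buffer-style method amounts to, consistent with the removed lines; this class is illustrative and is not the actual sglang/srt/mem_cache/memory_pool.py implementation:

import torch

class ToyKVPool:
    """Illustrative KV pool: one (num_slots, num_heads, head_dim) buffer pair per layer."""

    def __init__(self, num_layers: int, num_slots: int, num_heads: int, head_dim: int):
        self.k_buffers = [torch.zeros(num_slots, num_heads, head_dim) for _ in range(num_layers)]
        self.v_buffers = [torch.zeros(num_slots, num_heads, head_dim) for _ in range(num_layers)]

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k_buffers[layer_id]

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v_buffers[layer_id]

    def set_kv_buffer(self, layer_id: int, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor):
        # Same effect as the removed k_cache[loc] = cache_k / v_cache[loc] = cache_v,
        # but kept behind the pool so the storage layout can change in one place.
        self.k_buffers[layer_id][loc] = cache_k
        self.v_buffers[layer_id][loc] = cache_v

pool = ToyKVPool(num_layers=1, num_slots=16, num_heads=2, head_dim=4)
loc = torch.tensor([3, 7])
pool.set_kv_buffer(0, loc, torch.randn(2, 2, 4), torch.randn(2, 2, 4))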
@@ -0,0 +1,101 @@
+import logging
+
+import torch
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+# TODO: move this dict to another place
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+
+class Sampler(CustomOp):
+    def __init__(self):
+        super().__init__()
+
+    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+
+        logits = sampling_info.penalizer_orchestrator.apply(logits)
+
+        probs = torch.softmax(logits, dim=-1)
+
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            if sampling_info.min_ps.any():
+                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+            )
+
+        if not torch.all(success):
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
+
+        return batch_next_token_ids
+
+    def forward_native():
+        raise NotImplementedError("Native forward is not implemented yet.")
+
+
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
+):
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
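To try the torch-native fallback sampler from this new module in isolation (assuming the wheel and its flashinfer and vllm dependencies are installed, since the module imports both at the top level), something like the following should work:

import torch
from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

# Two toy distributions over a vocabulary of 8 tokens (illustrative values).
probs = torch.softmax(torch.randn(2, 8), dim=-1)
top_ks = torch.tensor([3, 5])        # keep at most k tokens per row
top_ps = torch.tensor([0.9, 0.8])    # nucleus mass per row
min_ps = torch.tensor([0.0, 0.05])   # min-p threshold, relative to the max probability

token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)
print(token_ids.tolist(), success.tolist())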
@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
 import dataclasses
 import logging
 import multiprocessing
-import os
 from enum import Enum, auto

 import numpy as np
@@ -36,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -194,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
@@ -212,6 +208,4 @@
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
@@ -17,7 +17,6 @@ limitations under the License.

 import logging
 import multiprocessing
-import os
 from typing import List

 import zmq
@@ -28,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -53,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue

-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)

         if not self.is_dp_worker:
@@ -134,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)

     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
@@ -167,6 +166,4 @@
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
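Both controllers now call configure_logger from sglang/srt/utils.py (also touched in this release) instead of calling logging.basicConfig inline, and pass a per-process prefix such as " DP0 TP0". A rough stand-in for what such a helper does, sketched from the removed basicConfig call; it is not the actual utility:

import logging

def configure_logger_sketch(log_level: str, prefix: str = ""):
    # Same idea as the removed inline basicConfig, plus a prefix so logs from
    # multiple controller/worker processes can be told apart when interleaved.
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format=f"[%(asctime)s{prefix}] %(message)s",
        datefmt="%H:%M:%S",
        force=True,
    )

configure_logger_sketch("info", prefix=" DP1 TP0")
logging.getLogger(__name__).info("controller started")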
@@ -17,7 +17,6 @@ limitations under the License.

 import asyncio
 import dataclasses
-import inspect
 from typing import List

 import uvloop
@@ -29,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 @dataclasses.dataclass
 class DecodeStatus:
+    """Store the status of incremental decoding."""
+
     vid: int
     decoded_text: str
     decode_ids: List[int]
@@ -47,11 +49,14 @@ class DecodeStatus:


 class DetokenizerManager:
+    """DetokenizerManager is a process that detokenizes the token ids."""
+
     def __init__(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
@@ -71,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            recv_obj = await self.recv_from_router.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,
@@ -84,15 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue

             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -126,8 +137,7 @@ class DetokenizerManager:
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

-            # Trim stop str
-            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
@@ -144,6 +154,7 @@ class DetokenizerManager:

                 output_strs.append(s.decoded_text + new_text)

+                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                     if pos != -1:
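The decode loop above keeps, per request, a window of already-surfaced token ids and a read window, and appends only the text the larger window adds on top of the smaller one. A rough standalone sketch of that incremental-detokenization idea with a Hugging Face tokenizer; the helper and offsets here are illustrative and do not mirror the manager's actual fields:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works

def incremental_decode(token_ids, surr_offset, read_offset):
    # Decode the already-surfaced window and the full window; emit only the
    # difference, unless it ends in an incomplete character and must wait.
    surr_text = tokenizer.decode(token_ids[surr_offset:read_offset])
    full_text = tokenizer.decode(token_ids[surr_offset:])
    if full_text.endswith("\ufffd"):
        return "", surr_offset, read_offset        # wait for more tokens
    return full_text[len(surr_text):], read_offset, len(token_ids)

ids, surr, read, out = [], 0, 0, ""
for tok in tokenizer.encode("Hello world, incremental decoding!"):
    ids.append(tok)                                 # a new token arrives
    new_text, surr, read = incremental_decode(ids, surr, read)
    out += new_text
print(out)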
@@ -22,10 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

-import torch
-
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams


 @dataclass
@@ -43,9 +41,9 @@ class GenerateReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob.
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return.
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
@@ -77,7 +75,7 @@ class GenerateReqInput:
             if self.return_logprob is None:
                 self.return_logprob = False
             if self.logprob_start_len is None:
-                self.logprob_start_len = 0
+                self.logprob_start_len = -1
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
@@ -143,7 +141,7 @@ class GenerateReqInput:
                 self.return_logprob = [self.return_logprob] * num

             if self.logprob_start_len is None:
-                self.logprob_start_len = [0] * num
+                self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num

@@ -155,16 +153,27 @@

 @dataclass
 class TokenizedGenerateReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # The pixel values for input images
     pixel_values: List[float]
+    # The hash of input images
     image_hash: int
+    # The image size
     image_size: List[int]
+    # The sampling parameters
     sampling_params: SamplingParams
+    # Whether to return the logprobs
     return_logprob: bool
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: int
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: int
+    # Whether to stream output
     stream: bool


@@ -215,15 +224,21 @@ class EmbeddingReqInput:

 @dataclass
 class TokenizedEmbeddingReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # Dummy sampling params for compatibility
     sampling_params: SamplingParams


 @dataclass
 class BatchTokenIDOut:
+    # The request id
     rids: List[str]
+    # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
@@ -236,17 +251,25 @@ class BatchTokenIDOut:

 @dataclass
 class BatchStrOut:
+    # The request id
     rids: List[str]
+    # The output decoded strings
     output_strs: List[str]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchEmbeddingOut:
+    # The request id
     rids: List[str]
+    # The output embedding
     embeddings: List[List[float]]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]


@@ -256,10 +279,20 @@ class FlushCacheReq:


 @dataclass
-class AbortReq:
-    rid: str
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None


 @dataclass
-class DetokenizeReqInput:
-    input_ids: List[int]
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens

         self.can_run_list = []
         self.new_inflight_req = None
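The new mixed_with_decode_tokens argument reserves space for decode tokens that run in the same mixed batch by shrinking every prefill budget up front. A small worked example with illustrative numbers:

# Budgets before reserving space for in-flight decode tokens (illustrative numbers).
rem_total_tokens, rem_input_tokens, rem_chunk_tokens = 16384, 8192, 4096
mixed_with_decode_tokens = 256  # e.g. one decode token for each of 256 running requests

rem_total_tokens -= mixed_with_decode_tokens      # 16128
rem_input_tokens -= mixed_with_decode_tokens      # 7936
if rem_chunk_tokens is not None:                  # chunked-prefill budget, if enabled
    rem_chunk_tokens -= mixed_with_decode_tokens  # 3840

print(rem_total_tokens, rem_input_tokens, rem_chunk_tokens)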