sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py

@@ -1,15 +1,14 @@
  """A tensor parallel worker."""

- import asyncio
  import logging
+ import multiprocessing
+ import pickle
  import time
  import warnings
- from concurrent.futures import ThreadPoolExecutor
  from typing import List, Optional

- import rpyc
  import torch
- from rpyc.utils.classic import obtain
+ import torch.distributed as dist

  from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
  TokenizedGenerateReqInput,
  )
  from sglang.srt.model_config import ModelConfig
- from sglang.srt.server_args import ModelPortArgs, ServerArgs
+ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
- connect_rpyc_service,
  get_int_token_logit_bias,
  is_multimodal_model,
  set_random_seed,
- start_rpyc_service_process,
  suppress_other_loggers,
  )
  from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
  gpu_id: int,
  tp_rank: int,
  server_args: ServerArgs,
- model_port_args: ModelPortArgs,
+ nccl_port: int,
  model_overide_args: dict,
  ):
- server_args, model_port_args = obtain(server_args), obtain(model_port_args)
  suppress_other_loggers()

  # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
  gpu_id=gpu_id,
  tp_rank=tp_rank,
  tp_size=server_args.tp_size,
- nccl_port=model_port_args.nccl_port,
+ nccl_port=nccl_port,
  server_args=server_args,
  )

@@ -107,6 +103,9 @@ class ModelTpServer:
  if server_args.max_running_requests is None
  else server_args.max_running_requests
  )
+ self.max_running_requests = min(
+ self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+ )
  self.int_token_logit_bias = torch.tensor(
  get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
  )
@@ -117,13 +116,9 @@ class ModelTpServer:
  f"[gpu_id={self.gpu_id}] "
  f"max_total_num_tokens={self.max_total_num_tokens}, "
  f"max_prefill_tokens={self.max_prefill_tokens}, "
+ f"max_running_requests={self.max_running_requests}, "
  f"context_len={self.model_config.context_len}"
  )
- if self.tp_rank == 0:
- logger.info(
- f"[gpu_id={self.gpu_id}] "
- f"server_args: {server_args.print_mode_args()}"
- )

  # Init cache
  self.tree_cache = RadixCache(
@@ -165,22 +160,16 @@ class ModelTpServer:
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
- self.new_token_ratio = min(
- global_config.base_new_token_ratio * server_args.schedule_conservativeness,
- 1.0,
- )
  self.min_new_token_ratio = min(
  global_config.base_min_new_token_ratio
  * server_args.schedule_conservativeness,
  1.0,
  )
+ self.new_token_ratio = self.min_new_token_ratio
  self.new_token_ratio_decay = global_config.new_token_ratio_decay
  self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

  def exposed_step(self, recv_reqs):
- if not isinstance(recv_reqs, list):
- recv_reqs = obtain(recv_reqs)
-
  try:
  # Recv requests
  for recv_req in recv_reqs:
@@ -228,23 +217,7 @@ class ModelTpServer:

  # Print stats
  if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
- num_used = self.max_total_num_tokens - (
- self.token_to_kv_pool.available_size()
- + self.tree_cache.evictable_size()
- )
- throughput = self.num_generated_tokens / (
- time.time() - self.last_stats_tic
- )
- self.num_generated_tokens = 0
- self.last_stats_tic = time.time()
- logger.info(
- f"[gpu_id={self.gpu_id}] Decode batch. "
- f"#running-req: {len(self.running_batch.reqs)}, "
- f"#token: {num_used}, "
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {throughput:.2f}, "
- f"#queue-req: {len(self.forward_queue)}"
- )
+ self.print_stats()

  if self.running_batch.is_empty():
  self.running_batch = None
@@ -253,17 +226,35 @@ class ModelTpServer:
  if self.out_pyobjs and self.running_batch.has_stream():
  break
  else:
- # Check the available size
- available_size = (
- self.token_to_kv_pool.available_size()
- + self.tree_cache.evictable_size()
- )
- if available_size != self.max_total_num_tokens:
- warnings.warn(
- "Warning: "
- f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
- "KV cache pool leak detected!"
- )
+ self.check_memory()
+ self.new_token_ratio = global_config.init_new_token_ratio
+
+ def print_stats(self):
+ num_used = self.max_total_num_tokens - (
+ self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ )
+ throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+ self.num_generated_tokens = 0
+ self.last_stats_tic = time.time()
+ logger.info(
+ f"[gpu_id={self.gpu_id}] Decode batch. "
+ f"#running-req: {len(self.running_batch.reqs)}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"gen throughput (token/s): {throughput:.2f}, "
+ f"#queue-req: {len(self.forward_queue)}"
+ )
+
+ def check_memory(self):
+ available_size = (
+ self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ )
+ if available_size != self.max_total_num_tokens:
+ warnings.warn(
+ "Warning: "
+ f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+ "KV cache pool leak detected!"
+ )

  def handle_generate_request(
  self,
@@ -310,6 +301,12 @@ class ModelTpServer:
  self.model_config.context_len - 1 - len(req.origin_input_ids),
  self.max_total_num_tokens - 128 - len(req.origin_input_ids),
  )
+ if req.sampling_params.max_new_tokens < 0:
+ req.origin_input_ids = req.origin_input_ids[
+ : self.max_total_num_tokens - 128
+ ]
+ logger.error("Request longer than memory pool size, truncated!!!")
+
  self.forward_queue.append(req)

  def get_new_prefill_batch(self) -> Optional[Batch]:
@@ -343,7 +340,8 @@ class ModelTpServer:
  if self.running_batch:
  available_size -= sum(
  [
- (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
+ (r.sampling_params.max_new_tokens - len(r.output_ids))
+ * self.new_token_ratio
  for r in self.running_batch.reqs
  ]
  )
@@ -365,7 +363,9 @@ class ModelTpServer:
  req.image_offset += 1

  if (
- req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+ req.extend_input_len
+ + req.sampling_params.max_new_tokens
+ + new_batch_total_tokens
  < available_size
  and (
  req.extend_input_len + new_batch_input_tokens
@@ -377,7 +377,9 @@ class ModelTpServer:
  available_size += delta

  if not (
- req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+ req.extend_input_len
+ + req.sampling_params.max_new_tokens
+ + new_batch_total_tokens
  < available_size
  ):
  # Undo locking
@@ -419,12 +421,6 @@ class ModelTpServer:
  f"#running-req: {running_bs}, "
  f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
  )
- # logger.debug(
- # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
- # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
- # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
- # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
- # )

  # Return the new batch
  new_batch = Batch.init_new(
@@ -445,7 +441,7 @@ class ModelTpServer:
  # Forward and sample the next tokens
  if batch.extend_num_tokens != 0:
  output = self.model_runner.forward(batch, ForwardMode.EXTEND)
- next_token_ids, _ = batch.sample(output.next_token_logits)
+ next_token_ids = batch.sample(output.next_token_logits)

  # Move logprobs to cpu
  if output.next_token_logprobs is not None:
@@ -540,9 +536,10 @@ class ModelTpServer:
  # Check if decode out of memory
  if not batch.check_decode_mem():
  old_ratio = self.new_token_ratio
- self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

- retracted_reqs = batch.retract_decode()
+ retracted_reqs, new_token_ratio = batch.retract_decode()
+ self.new_token_ratio = new_token_ratio
+
  logger.info(
  "decode out of memory happened, "
  f"#retracted_reqs: {len(retracted_reqs)}, "
@@ -568,7 +565,7 @@ class ModelTpServer:

  # Forward and sample the next tokens
  output = self.model_runner.forward(batch, ForwardMode.DECODE)
- next_token_ids, _ = batch.sample(output.next_token_logits)
+ next_token_ids = batch.sample(output.next_token_logits)

  # Move logprobs to cpu
  if output.next_token_logprobs is not None:
@@ -596,9 +593,10 @@ class ModelTpServer:

  def handle_finished_requests(self, batch: Batch):
  output_rids = []
+ output_vids = []
  decoded_texts = []
- surr_output_ids = []
- read_output_ids = []
+ output_read_ids = []
+ output_read_offsets = []
  output_skip_special_tokens = []
  output_spaces_between_special_tokens = []
  output_meta_info = []
@@ -621,10 +619,11 @@ class ModelTpServer:
  )
  ):
  output_rids.append(req.rid)
+ output_vids.append(req.vid)
  decoded_texts.append(req.decoded_text)
- surr_ids, read_ids, _ = req.init_detokenize_incrementally()
- surr_output_ids.append(surr_ids)
- read_output_ids.append(read_ids)
+ read_ids, read_offset = req.init_incremental_detokenize()
+ output_read_ids.append(read_ids)
+ output_read_offsets.append(read_offset)
  output_skip_special_tokens.append(
  req.sampling_params.skip_special_tokens
  )
@@ -660,9 +659,10 @@ class ModelTpServer:
  self.out_pyobjs.append(
  BatchTokenIDOut(
  output_rids,
+ output_vids,
  decoded_texts,
- surr_output_ids,
- read_output_ids,
+ output_read_ids,
+ output_read_offsets,
  output_skip_special_tokens,
  output_spaces_between_special_tokens,
  output_meta_info,
@@ -727,87 +727,74 @@ class ModelTpServer:
  break


- class ModelTpService(rpyc.Service):
- exposed_ModelTpServer = ModelTpServer
-
-
- class ModelTpClient:
- def __init__(
- self,
- gpu_ids: List[int],
- server_args: ServerArgs,
- model_port_args: ModelPortArgs,
- model_overide_args,
- ):
- server_args, model_port_args = obtain(server_args), obtain(model_port_args)
- self.tp_size = server_args.tp_size
+ def run_tp_server(
+ gpu_id: int,
+ tp_rank: int,
+ server_args: ServerArgs,
+ nccl_port: int,
+ model_overide_args: dict,
+ ):
+ """Run a tensor parallel server."""
+ try:
+ model_server = ModelTpServer(
+ gpu_id,
+ tp_rank,
+ server_args,
+ nccl_port,
+ model_overide_args,
+ )
+ tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+ while True:
+ recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+ model_server.exposed_step(recv_reqs)
+ except Exception:
+ logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+ raise
+
+
+ def launch_tp_servers(
+ gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+ ):
+ """Launch multiple tensor parallel servers."""
+ procs = []
+ for i in tp_rank_range:
+ proc = multiprocessing.Process(
+ target=run_tp_server,
+ args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+ )
+ proc.start()
+ procs.append(proc)

- if self.tp_size * server_args.dp_size == 1:
- # Init model
- assert len(gpu_ids) == 1
- self.model_server = ModelTpService().exposed_ModelTpServer(
- gpu_ids[0],
- 0,
- server_args,
- model_port_args,
- model_overide_args,
- )
+ return procs

- # Wrap functions
- def async_wrap(f):
- async def _func(*args, **kwargs):
- return f(*args, **kwargs)

- return _func
+ def broadcast_recv_input(data, rank, dist_group):
+ """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

- self.step = async_wrap(self.model_server.exposed_step)
+ if rank == 0:
+ if len(data) == 0:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.broadcast(tensor_size, src=0, group=dist_group)
  else:
- with ThreadPoolExecutor(self.tp_size) as executor:
- # Launch model processes
- if server_args.nnodes == 1:
- self.procs = list(
- executor.map(
- lambda args: start_rpyc_service_process(*args),
- [
- (ModelTpService, p)
- for p in model_port_args.model_tp_ports
- ],
- )
- )
- addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
- else:
- addrs = [
- (ip, port)
- for ip, port in zip(
- model_port_args.model_tp_ips, model_port_args.model_tp_ports
- )
- ]
-
- self.model_services = list(
- executor.map(lambda args: connect_rpyc_service(*args), addrs)
- )
-
- # Init model
- def init_model(i):
- return self.model_services[i].ModelTpServer(
- gpu_ids[i],
- i,
- server_args,
- model_port_args,
- model_overide_args,
- )
-
- self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
- # Wrap functions
- def async_wrap(func_name):
- fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
- async def _func(*args, **kwargs):
- tasks = [f(*args, **kwargs) for f in fs]
- await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
- return obtain(tasks[0].value)
-
- return _func
-
- self.step = async_wrap("step")
+ serialized_data = pickle.dumps(data)
+ size = len(serialized_data)
+ tensor_data = torch.ByteTensor(list(serialized_data))
+ tensor_size = torch.tensor([size], dtype=torch.long)
+
+ dist.broadcast(tensor_size, src=0, group=dist_group)
+ dist.broadcast(tensor_data, src=0, group=dist_group)
+ else:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.broadcast(tensor_size, src=0, group=dist_group)
+ size = tensor_size.item()
+
+ if size == 0:
+ return []
+
+ tensor_data = torch.empty(size, dtype=torch.uint8)
+ dist.broadcast(tensor_data, src=0, group=dist_group)
+
+ serialized_data = bytes(tensor_data.tolist())
+ data = pickle.loads(serialized_data)
+ return data
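The hunk above replaces the RPyC-based ModelTpService/ModelTpClient with plain multiprocessing workers that receive their work through a pickle-over-torch.distributed broadcast. The sketch below is a hypothetical, standalone illustration of that pattern, not code shipped in the package: it re-implements the same broadcast_recv_input scheme and drives it from a toy main, assuming a CPU "gloo" process group; the file name tp_broadcast_demo.py and the dummy request dicts are invented for the example.

```python
# tp_broadcast_demo.py -- minimal sketch of the pickle + dist.broadcast pattern.
import pickle

import torch
import torch.distributed as dist


def broadcast_recv_input(data, rank, dist_group):
    """Rank 0 broadcasts a picklable object; every rank returns its own copy."""
    if rank == 0:
        payload = pickle.dumps(data)
        size = torch.tensor([len(payload)], dtype=torch.long)
        # Send the length first, then the raw bytes as a uint8 tensor.
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(torch.ByteTensor(list(payload)), src=0, group=dist_group)
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=dist_group)
    buf = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=dist_group)
    return pickle.loads(bytes(buf.tolist()))


if __name__ == "__main__":
    # Each rank stands in for one tensor-parallel worker; rank 0 owns the queue.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    reqs = [{"rid": "req-0", "input_ids": [1, 2, 3]}] if rank == 0 else None
    reqs = broadcast_recv_input(reqs, rank, dist.group.WORLD)
    print(f"rank {rank} got {len(reqs)} request(s)")
    dist.destroy_process_group()
```

Run with, e.g., `torchrun --nproc_per_node=2 tp_broadcast_demo.py`; rank 0 plays the scheduler that holds the request queue, while the other ranks mirror run_tp_server's receive loop.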
sglang/srt/managers/detokenizer_manager.py

@@ -1,7 +1,9 @@
  """DetokenizerManager is a process that detokenizes the token ids."""

  import asyncio
+ import dataclasses
  import inspect
+ from typing import List

  import uvloop
  import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


+ @dataclasses.dataclass
+ class DecodeStatus:
+ vid: int
+ decoded_text: str
+ decode_ids: List[int]
+ surr_offset: int
+ read_offset: int
+
+
  class DetokenizerManager:
  def __init__(
  self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
  trust_remote_code=server_args.trust_remote_code,
  )

+ self.decode_status = {}
+
  async def handle_loop(self):
  while True:
  recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
  assert isinstance(recv_obj, BatchTokenIDOut)
+ bs = len(recv_obj.rids)
+
+ # Initialize decode status
+ read_ids, surr_ids = [], []
+ for i in range(bs):
+ rid = recv_obj.rids[i]
+ vid = recv_obj.vids[i]
+ if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+ s = DecodeStatus(
+ vid=vid,
+ decoded_text=recv_obj.decoded_texts[i],
+ decode_ids=recv_obj.decode_ids[i],
+ surr_offset=0,
+ read_offset=recv_obj.read_offsets[i],
+ )
+ self.decode_status[rid] = s
+ else:
+ s = self.decode_status[rid]
+ s.decode_ids = recv_obj.decode_ids[i]
+
+ read_ids.append(s.decode_ids[s.surr_offset :])
+ surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

  # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
  surr_texts = self.tokenizer.batch_decode(
- recv_obj.surr_output_ids,
+ surr_ids,
  skip_special_tokens=recv_obj.skip_special_tokens[0],
  spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
  )
  read_texts = self.tokenizer.batch_decode(
- recv_obj.read_output_ids,
+ read_ids,
  skip_special_tokens=recv_obj.skip_special_tokens[0],
  spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
  )
@@ -55,11 +90,20 @@ class DetokenizerManager:
  # Trim stop str
  # TODO(lmzheng): handle the case where multiple stop strs are hit
  output_strs = []
- for i in range(len(recv_obj.rids)):
+ for i in range(bs):
+ s = self.decode_status[recv_obj.rids[i]]
  new_text = read_texts[i][len(surr_texts[i]) :]
  if recv_obj.finished_reason[i] is None:
- new_text = find_printable_text(new_text)
- output_strs.append(recv_obj.decoded_texts[i] + new_text)
+ # Streaming chunk: update the decode status
+ if len(new_text) > 0 and not new_text.endswith("�"):
+ s.decoded_text = s.decoded_text + new_text
+ s.surr_offset = s.read_offset
+ s.read_offset = len(s.decode_ids)
+ new_text = ""
+ else:
+ new_text = find_printable_text(new_text)
+
+ output_strs.append(s.decoded_text + new_text)

  if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
  pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
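The detokenizer now keeps per-request DecodeStatus state and advances surr_offset/read_offset itself instead of receiving surrogate and read token-id lists from the workers. Below is a minimal, hypothetical sketch of that bookkeeping (simplified: no vid versioning, and a generic `decode` callable stands in for tokenizer.batch_decode on a single sequence). Text is only committed when the newly decoded chunk does not end with the replacement character "�", so a partially received multi-byte character is never streamed out.

```python
# Hypothetical helper, not the package's code: incremental detokenization offsets.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class DecodeStatus:
    decoded_text: str = ""
    decode_ids: List[int] = field(default_factory=list)
    surr_offset: int = 0  # start of the "surrounding" context window
    read_offset: int = 0  # end of the last committed chunk


def stream_step(
    s: DecodeStatus, new_token_ids: List[int], decode: Callable[[List[int]], str]
) -> str:
    """Append new tokens and return the text that is safe to stream so far."""
    s.decode_ids.extend(new_token_ids)
    # Decode with a surrounding prefix so tokenizers that merge whitespace
    # across token boundaries still produce the right characters.
    surr_text = decode(s.decode_ids[s.surr_offset : s.read_offset])
    read_text = decode(s.decode_ids[s.surr_offset :])
    new_text = read_text[len(surr_text):]
    if new_text and not new_text.endswith("�"):
        # The chunk is complete: commit it and slide both offsets forward.
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
    # Otherwise hold the partial chunk back until more tokens arrive.
    return s.decoded_text
```

With a Hugging Face tokenizer, `decode` would typically be `lambda ids: tokenizer.decode(ids, skip_special_tokens=True)`.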
sglang/srt/managers/io_struct.py

@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams

  @dataclass
  class GenerateReqInput:
- # The input prompt
+ # The input prompt. It can be a single prompt or a batch of prompts.
  text: Optional[Union[List[str], str]] = None
- # The token ids for text; one can either specify text or input_ids
+ # The token ids for text; one can either specify text or input_ids.
  input_ids: Optional[Union[List[List[int]], List[int]]] = None
- # The image input
+ # The image input. It can be a file name, a url, or base64 encoded string.
+ # See also python/sglang/srt/utils.py:load_image.
  image_data: Optional[Union[List[str], str]] = None
- # The sampling_params
+ # The sampling_params.
  sampling_params: Union[List[Dict], Dict] = None
- # The request id
+ # The request id.
  rid: Optional[Union[List[str], str]] = None
- # Whether to return logprobs
+ # Whether to return logprobs.
  return_logprob: Optional[Union[List[bool], bool]] = None
- # The start location of the prompt for return_logprob
+ # The start location of the prompt for return_logprob.
  logprob_start_len: Optional[Union[List[int], int]] = None
- # The number of top logprobs to return
+ # The number of top logprobs to return.
  top_logprobs_num: Optional[Union[List[int], int]] = None
- # Whether to detokenize tokens in logprobs
+ # Whether to detokenize tokens in logprobs.
  return_text_in_logprobs: bool = False
- # Whether to stream output
+ # Whether to stream output.
  stream: bool = False

  def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
  self.text is not None and self.input_ids is not None
  ):
  raise ValueError("Either text or input_ids should be provided.")
-
- if self.text is not None:
- is_single = isinstance(self.text, str)
+ if self.sampling_params.get("n", 1) != 1:
+ is_single = False
  else:
- is_single = isinstance(self.input_ids[0], int)
+ if self.text is not None:
+ is_single = isinstance(self.text, str)
+ else:
+ is_single = isinstance(self.input_ids[0], int)
  self.is_single = is_single

  if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
  if self.top_logprobs_num is None:
  self.top_logprobs_num = 0
  else:
- num = len(self.text) if self.text is not None else len(self.input_ids)
+
+ parallel_sample_num = self.sampling_params.get("n", 1)
+
+ if parallel_sample_num != 1:
+ # parallel sampling +1 represents the original prefill stage
+ num = parallel_sample_num + 1
+ if isinstance(self.text, List):
+ ## suppot batch operation
+ self.batch_size = len(self.text)
+ num = num * len(self.text)
+ else:
+ self.batch_size = 1
+ else:
+ ## support select operation
+ num = len(self.text) if self.text is not None else len(self.input_ids)
+ self.batch_size = num

  if self.image_data is None:
  self.image_data = [None] * num
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
  @dataclass
  class BatchTokenIDOut:
  rids: List[str]
+ vids: List[int]
  decoded_texts: List[str]
- surr_output_ids: List[List[int]]
- read_output_ids: List[List[int]]
+ decode_ids: List[int]
+ read_offsets: List[int]
  skip_special_tokens: List[bool]
  spaces_between_special_tokens: List[bool]
  meta_info: List[Dict]
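For reference, the new post_init branch sizes the per-request fields (rid, image_data, sampling_params, ...) as (n + 1) * batch_size when parallel sampling is requested: n samples per prompt plus one extra slot for the shared prefill pass. The helper below is a hypothetical restatement of that arithmetic, not part of the package.

```python
# Hypothetical helper mirroring GenerateReqInput.post_init's slot counting.
from typing import List, Optional, Union


def expected_request_slots(
    text: Optional[Union[List[str], str]], parallel_sample_num: int
) -> int:
    """How many per-request entries post_init pads fields like image_data to."""
    if parallel_sample_num != 1:
        batch_size = len(text) if isinstance(text, list) else 1
        # n samples per prompt plus one slot for the shared prefill stage.
        return (parallel_sample_num + 1) * batch_size
    # n == 1: one slot per prompt in the batch.
    return len(text) if isinstance(text, list) else 1


assert expected_request_slots(["a", "b"], 3) == 8  # (3 + 1) * 2
assert expected_request_slots(["a", "b", "c"], 1) == 3
```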