sglang-0.1.21-py3-none-any.whl → sglang-0.1.22-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py

@@ -1,15 +1,14 @@
 """A tensor parallel worker."""
 
-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional
 
-import rpyc
 import torch
-from rpyc.utils.classic import obtain
+import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
 
         # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=model_port_args.nccl_port,
+            nccl_port=nccl_port,
             server_args=server_args,
         )
 
@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
        try:
            # Recv requests
            for recv_req in recv_reqs:
@@ -228,23 +221,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                        self.print_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -253,17 +230,34 @@ class ModelTpServer:
                     if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
-                # Check the available size
-                available_size = (
-                    self.token_to_kv_pool.available_size()
-                    + self.tree_cache.evictable_size()
-                )
-                if available_size != self.max_total_num_tokens:
-                    warnings.warn(
-                        "Warning: "
-                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                        "KV cache pool leak detected!"
-                    )
+                self.check_memory()
+
+    def print_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.forward_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
 
     def handle_generate_request(
         self,
@@ -310,6 +304,12 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
@@ -343,7 +343,8 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
+                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -365,7 +366,9 @@ class ModelTpServer:
                 req.image_offset += 1
 
             if (
-                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                req.extend_input_len
+                + req.sampling_params.max_new_tokens
+                + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -377,7 +380,9 @@ class ModelTpServer:
                     available_size += delta
 
                 if not (
-                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                    req.extend_input_len
+                    + req.sampling_params.max_new_tokens
+                    + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -419,12 +424,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )
 
         # Return the new batch
         new_batch = Batch.init_new(
@@ -445,7 +444,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
        if batch.extend_num_tokens != 0:
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids, _ = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(output.next_token_logits)
 
            # Move logprobs to cpu
            if output.next_token_logprobs is not None:
@@ -568,7 +567,7 @@ class ModelTpServer:
 
        # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)
 
        # Move logprobs to cpu
        if output.next_token_logprobs is not None:
@@ -596,9 +595,10 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        output_vids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -621,10 +621,11 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -660,9 +661,10 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    output_vids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -727,87 +729,74 @@ class ModelTpServer:
             break
 
 
-class ModelTpService(rpyc.Service):
-    exposed_ModelTpServer = ModelTpServer
-
-
-class ModelTpClient:
-    def __init__(
-        self,
-        gpu_ids: List[int],
-        server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
-        model_overide_args,
-    ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-        self.tp_size = server_args.tp_size
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
 
-        if self.tp_size * server_args.dp_size == 1:
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
+    return procs
 
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
 
-                return _func
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
-            self.step = async_wrap(self.model_server.exposed_step)
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
         else:
-            with ThreadPoolExecutor(self.tp_size) as executor:
-                # Launch model processes
-                if server_args.nnodes == 1:
-                    self.procs = list(
-                        executor.map(
-                            lambda args: start_rpyc_service_process(*args),
-                            [
-                                (ModelTpService, p)
-                                for p in model_port_args.model_tp_ports
-                            ],
-                        )
-                    )
-                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                else:
-                    addrs = [
-                        (ip, port)
-                        for ip, port in zip(
-                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                        )
-                    ]
-
-                self.model_services = list(
-                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                )
-
-                # Init model
-                def init_model(i):
-                    return self.model_services[i].ModelTpServer(
-                        gpu_ids[i],
-                        i,
-                        server_args,
-                        model_port_args,
-                        model_overide_args,
-                    )
-
-                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-                # Wrap functions
-                def async_wrap(func_name):
-                    fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                    async def _func(*args, **kwargs):
-                        tasks = [f(*args, **kwargs) for f in fs]
-                        await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                        return obtain(tasks[0].value)
-
-                    return _func
-
-                self.step = async_wrap("step")
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
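The main change in tp_worker.py is the removal of the rpyc control plane: ModelTpService and ModelTpClient are gone, tensor parallel workers are launched as plain multiprocessing processes, and rank 0 distributes request batches through broadcast_recv_input, which pickles the Python objects and broadcasts first the byte size and then the byte payload over a CPU torch.distributed group. The standalone sketch below shows the same pattern in isolation; the helper name broadcast_pyobj, the gloo setup, and the port number are illustrative assumptions, not sglang APIs.

# Standalone sketch (not sglang code) of the pickle-over-torch.distributed
# broadcast pattern used by the new tp_worker.py.
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(obj, rank, group):
    """Send `obj` from rank 0 to all ranks; returns the object on every rank."""
    if rank == 0:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long)
        data = torch.ByteTensor(list(payload))
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(data, src=0, group=group)
        return obj
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        data = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(data, src=0, group=group)
        return pickle.loads(bytes(data.tolist()))


def worker(rank, world_size):
    # Rendezvous settings for this toy example only.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29533"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    group = dist.group.WORLD

    # Rank 0 owns the request list; the other ranks receive a copy.
    reqs = [{"rid": "req-0", "input_ids": [1, 2, 3]}] if rank == 0 else None
    reqs = broadcast_pyobj(reqs, rank, group)
    print(f"rank {rank} received {reqs}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Compared with the previous rpyc design, this keeps all tensor parallel ranks in lockstep with two collective calls per batch instead of running a separate RPC service per rank.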
sglang/srt/managers/detokenizer_manager.py

@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    vid: int
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                vid = recv_obj.vids[i]
+                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+                    s = DecodeStatus(
+                        vid=vid,
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +90,20 @@
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
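The detokenizer now keeps per-request DecodeStatus state and only commits newly decoded text once it no longer ends with the Unicode replacement character "�", i.e. once a partially decoded multi-token character is complete; until then the chunk is held back via find_printable_text. The toy below mirrors that surr_offset/read_offset bookkeeping; ToyDecodeStatus, the one-byte-per-token "tokenizer", and the rstrip("�") stand-in are invented for illustration and are not sglang code.

# Toy sketch of incremental detokenization with surr/read offsets. A fake
# byte-level "tokenizer" (one UTF-8 byte per token) stands in for the real
# one so the example runs without downloading a model.
from dataclasses import dataclass, field
from typing import List


def decode(token_ids: List[int]) -> str:
    """Decode byte 'tokens'; incomplete UTF-8 sequences end with U+FFFD."""
    return bytes(token_ids).decode("utf-8", errors="replace")


@dataclass
class ToyDecodeStatus:
    decoded_text: str = ""
    decode_ids: List[int] = field(default_factory=list)
    surr_offset: int = 0
    read_offset: int = 0


def step(s: ToyDecodeStatus, new_token_id: int) -> str:
    """Append one token and return the text that can be shown so far."""
    s.decode_ids.append(new_token_id)
    surr_text = decode(s.decode_ids[s.surr_offset : s.read_offset])
    read_text = decode(s.decode_ids[s.surr_offset :])
    new_text = read_text[len(surr_text):]
    if new_text and not new_text.endswith("�"):
        # Complete characters: commit them and advance both offsets.
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        new_text = ""
    else:
        # Crude stand-in for find_printable_text: hold back the partial char.
        new_text = new_text.rstrip("�")
    return s.decoded_text + new_text


status = ToyDecodeStatus()
for tok in "héllo".encode("utf-8"):  # 'é' is two bytes, so two "tokens"
    print(repr(step(status, tok)))

The key point the real code relies on is that nothing is committed to decoded_text while the tail of the window still decodes to a replacement character, so streamed chunks never contain half of a multi-byte character.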
sglang/srt/managers/io_struct.py

@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams
 
 @dataclass
 class GenerateReqInput:
-    # The input prompt
+    # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids
+    # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The image input
+    # The image input. It can be a file name, a url, or base64 encoded string.
+    # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params
+    # The sampling_params.
     sampling_params: Union[List[Dict], Dict] = None
-    # The request id
+    # The request id.
     rid: Optional[Union[List[str], str]] = None
-    # Whether to return logprobs
+    # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob
+    # The start location of the prompt for return_logprob.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return
+    # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs
+    # Whether to detokenize tokens in logprobs.
     return_text_in_logprobs: bool = False
-    # Whether to stream output
+    # Whether to stream output.
     stream: bool = False
 
     def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-
-        if self.text is not None:
-            is_single = isinstance(self.text, str)
+        if self.sampling_params.get("n", 1) != 1:
+            is_single = False
         else:
-            is_single = isinstance(self.input_ids[0], int)
+            if self.text is not None:
+                is_single = isinstance(self.text, str)
+            else:
+                is_single = isinstance(self.input_ids[0], int)
         self.is_single = is_single
 
         if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text) if self.text is not None else len(self.input_ids)
+
+            parallel_sample_num = self.sampling_params.get("n", 1)
+
+            if parallel_sample_num != 1:
+                # parallel sampling +1 represents the original prefill stage
+                num = parallel_sample_num + 1
+                if isinstance(self.text, List):
+                    ## suppot batch operation
+                    self.batch_size = len(self.text)
+                    num = num * len(self.text)
+                else:
+                    self.batch_size = 1
+            else:
+                ## support select operation
+                num = len(self.text) if self.text is not None else len(self.input_ids)
+                self.batch_size = num
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    vids: List[int]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
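GenerateReqInput.post_init now also sizes its per-request fields for parallel sampling: when the sampling parameter n is greater than 1, each prompt reserves n sample slots plus one extra slot for the shared prefill pass, multiplied by the batch size when text is a list of prompts. Below is a minimal sketch of that arithmetic; the helper name is invented for illustration and is not part of sglang.

# Minimal sketch of the slot-count arithmetic for parallel sampling (n > 1).
def parallel_sampling_slots(num_prompts: int, n: int) -> int:
    assert n > 1, "the regular path (n == 1) allocates one slot per prompt"
    # n sampled generations plus one slot for the shared prefill, per prompt.
    return num_prompts * (n + 1)


print(parallel_sampling_slots(num_prompts=1, n=4))  # 5  (1 prefill + 4 samples)
print(parallel_sampling_slots(num_prompts=3, n=2))  # 9  (3 prompts x (1 + 2))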