sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py

@@ -1,15 +1,14 @@
  """A tensor parallel worker."""
 
- import asyncio
  import logging
+ import multiprocessing
+ import pickle
  import time
  import warnings
- from concurrent.futures import ThreadPoolExecutor
  from typing import List, Optional
 
- import rpyc
  import torch
- from rpyc.utils.classic import obtain
+ import torch.distributed as dist
 
  from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
      TokenizedGenerateReqInput,
  )
  from sglang.srt.model_config import ModelConfig
- from sglang.srt.server_args import ModelPortArgs, ServerArgs
+ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
-     connect_rpyc_service,
      get_int_token_logit_bias,
      is_multimodal_model,
      set_random_seed,
-     start_rpyc_service_process,
      suppress_other_loggers,
  )
  from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
          gpu_id: int,
          tp_rank: int,
          server_args: ServerArgs,
-         model_port_args: ModelPortArgs,
-         model_overide_args,
+         nccl_port: int,
+         model_overide_args: dict,
      ):
-         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
          suppress_other_loggers()
 
          # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
              gpu_id=gpu_id,
              tp_rank=tp_rank,
              tp_size=server_args.tp_size,
-             nccl_port=model_port_args.nccl_port,
+             nccl_port=nccl_port,
              server_args=server_args,
          )
 
@@ -98,7 +94,7 @@ class ModelTpServer:
          )
          self.max_total_num_tokens = self.model_runner.max_total_num_tokens
          self.max_prefill_tokens = (
-             8192
+             16384
              if server_args.max_prefill_tokens is None
              else server_args.max_prefill_tokens
          )
@@ -178,9 +174,6 @@ class ModelTpServer:
          self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
      def exposed_step(self, recv_reqs):
-         if self.tp_size * self.dp_size != 1:
-             recv_reqs = obtain(recv_reqs)
-
          try:
              # Recv requests
              for recv_req in recv_reqs:
@@ -206,11 +199,11 @@ class ModelTpServer:
 
      @torch.inference_mode()
      def forward_step(self):
-         new_batch = self.get_new_fill_batch()
+         new_batch = self.get_new_prefill_batch()
 
          if new_batch is not None:
-             # Run a new fill batch
-             self.forward_fill_batch(new_batch)
+             # Run a new prefill batch
+             self.forward_prefill_batch(new_batch)
              self.cache_filled_batch(new_batch)
 
              if not new_batch.is_empty():
@@ -219,33 +212,16 @@ class ModelTpServer:
                  else:
                      self.running_batch.merge(new_batch)
          else:
-             # Run decode batch
+             # Run a decode batch
              if self.running_batch is not None:
                  # Run a few decode batches continuously for reducing overhead
-                 for _ in range(10):
+                 for _ in range(global_config.num_continue_decode_steps):
                      self.num_generated_tokens += len(self.running_batch.reqs)
                      self.forward_decode_batch(self.running_batch)
 
                      # Print stats
-                     if self.tp_rank == 0:
-                         if self.decode_forward_ct % 40 == 0:
-                             num_used = self.max_total_num_tokens - (
-                                 self.token_to_kv_pool.available_size()
-                                 + self.tree_cache.evictable_size()
-                             )
-                             throughput = self.num_generated_tokens / (
-                                 time.time() - self.last_stats_tic
-                             )
-                             self.num_generated_tokens = 0
-                             self.last_stats_tic = time.time()
-                             logger.info(
-                                 f"[gpu_id={self.gpu_id}] Decode batch. "
-                                 f"#running-req: {len(self.running_batch.reqs)}, "
-                                 f"#token: {num_used}, "
-                                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                                 f"gen throughput (token/s): {throughput:.2f}, "
-                                 f"#queue-req: {len(self.forward_queue)}"
-                             )
+                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                         self.print_stats()
 
                      if self.running_batch.is_empty():
                          self.running_batch = None
@@ -254,17 +230,34 @@ class ModelTpServer:
                      if self.out_pyobjs and self.running_batch.has_stream():
                          break
              else:
-                 # Check the available size
-                 available_size = (
-                     self.token_to_kv_pool.available_size()
-                     + self.tree_cache.evictable_size()
-                 )
-                 if available_size != self.max_total_num_tokens:
-                     warnings.warn(
-                         "Warning: "
-                         f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                         "KV cache pool leak detected!"
-                     )
+                 self.check_memory()
+
+     def print_stats(self):
+         num_used = self.max_total_num_tokens - (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+         self.num_generated_tokens = 0
+         self.last_stats_tic = time.time()
+         logger.info(
+             f"[gpu_id={self.gpu_id}] Decode batch. "
+             f"#running-req: {len(self.running_batch.reqs)}, "
+             f"#token: {num_used}, "
+             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+             f"gen throughput (token/s): {throughput:.2f}, "
+             f"#queue-req: {len(self.forward_queue)}"
+         )
+
+     def check_memory(self):
+         available_size = (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         if available_size != self.max_total_num_tokens:
+             warnings.warn(
+                 "Warning: "
+                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                 "KV cache pool leak detected!"
+             )
 
      def handle_generate_request(
          self,
@@ -311,10 +304,18 @@ class ModelTpServer:
              self.model_config.context_len - 1 - len(req.origin_input_ids),
              self.max_total_num_tokens - 128 - len(req.origin_input_ids),
          )
+         if req.sampling_params.max_new_tokens < 0:
+             req.origin_input_ids = req.origin_input_ids[
+                 : self.max_total_num_tokens - 128
+             ]
+             logger.error("Request longer than memory pool size, truncated!!!")
+
          self.forward_queue.append(req)
 
-     def get_new_fill_batch(self) -> Optional[Batch]:
-         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+     def get_new_prefill_batch(self) -> Optional[Batch]:
+         running_bs = (
+             len(self.running_batch.reqs) if self.running_batch is not None else 0
+         )
          if running_bs >= self.max_running_requests:
              return
 
@@ -342,7 +343,8 @@ class ModelTpServer:
          if self.running_batch:
              available_size -= sum(
                  [
-                     (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
+                     (r.sampling_params.max_new_tokens - len(r.output_ids))
+                     * self.new_token_ratio
                      for r in self.running_batch.reqs
                  ]
              )
@@ -356,7 +358,7 @@ class ModelTpServer:
                      req.prefix_indices = req.prefix_indices[:-delta]
                      if req.image_offset is not None:
                          req.image_offset += delta
-             if req.extend_input_len == 0 and req.max_new_tokens() > 0:
+             if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                  # Need at least one token to compute logits
                  req.extend_input_len = 1
                  req.prefix_indices = req.prefix_indices[:-1]
@@ -364,7 +366,9 @@ class ModelTpServer:
                      req.image_offset += 1
 
              if (
-                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                 req.extend_input_len
+                 + req.sampling_params.max_new_tokens
+                 + new_batch_total_tokens
                  < available_size
                  and (
                      req.extend_input_len + new_batch_input_tokens
@@ -376,7 +380,9 @@ class ModelTpServer:
                  available_size += delta
 
                  if not (
-                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                     req.extend_input_len
+                     + req.sampling_params.max_new_tokens
+                     + new_batch_total_tokens
                      < available_size
                  ):
                      # Undo locking
@@ -387,7 +393,7 @@ class ModelTpServer:
                      # Add this request to the running batch
                      can_run_list.append(req)
                      new_batch_total_tokens += (
-                         req.extend_input_len + req.max_new_tokens()
+                         req.extend_input_len + req.sampling_params.max_new_tokens
                      )
                      new_batch_input_tokens += req.extend_input_len
              else:
@@ -401,9 +407,6 @@ class ModelTpServer:
 
          # Print stats
          if self.tp_rank == 0:
-             running_req = (
-                 0 if self.running_batch is None else len(self.running_batch.reqs)
-             )
              hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
              self.tree_cache_metrics["total"] += (
                  hit_tokens + new_batch_input_tokens
@@ -418,15 +421,9 @@ class ModelTpServer:
                  f"#new-token: {new_batch_input_tokens}, "
                  f"#cached-token: {hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                 f"#running-req: {running_req}, "
+                 f"#running-req: {running_bs}, "
                  f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
              )
-             # logger.debug(
-             #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-             #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-             #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-             #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-             # )
 
          # Return the new batch
          new_batch = Batch.init_new(
@@ -438,7 +435,7 @@ class ModelTpServer:
          self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
          return new_batch
 
-     def forward_fill_batch(self, batch: Batch):
+     def forward_prefill_batch(self, batch: Batch):
          # Build batch tensors
          batch.prepare_for_extend(
              self.model_config.vocab_size, self.int_token_logit_bias
@@ -447,7 +444,7 @@ class ModelTpServer:
          # Forward and sample the next tokens
          if batch.extend_num_tokens != 0:
              output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-             next_token_ids, _ = batch.sample(output.next_token_logits)
+             next_token_ids = batch.sample(output.next_token_logits)
 
              # Move logprobs to cpu
              if output.next_token_logprobs is not None:
@@ -570,7 +567,7 @@ class ModelTpServer:
 
          # Forward and sample the next tokens
          output = self.model_runner.forward(batch, ForwardMode.DECODE)
-         next_token_ids, _ = batch.sample(output.next_token_logits)
+         next_token_ids = batch.sample(output.next_token_logits)
 
          # Move logprobs to cpu
          if output.next_token_logprobs is not None:
@@ -598,9 +595,10 @@ class ModelTpServer:
 
      def handle_finished_requests(self, batch: Batch):
          output_rids = []
+         output_vids = []
          decoded_texts = []
-         surr_output_ids = []
-         read_output_ids = []
+         output_read_ids = []
+         output_read_offsets = []
          output_skip_special_tokens = []
          output_spaces_between_special_tokens = []
          output_meta_info = []
@@ -623,10 +621,11 @@ class ModelTpServer:
                  )
              ):
                  output_rids.append(req.rid)
+                 output_vids.append(req.vid)
                  decoded_texts.append(req.decoded_text)
-                 surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                 surr_output_ids.append(surr_ids)
-                 read_output_ids.append(read_ids)
+                 read_ids, read_offset = req.init_incremental_detokenize()
+                 output_read_ids.append(read_ids)
+                 output_read_offsets.append(read_offset)
                  output_skip_special_tokens.append(
                      req.sampling_params.skip_special_tokens
                  )
@@ -662,9 +661,10 @@ class ModelTpServer:
              self.out_pyobjs.append(
                  BatchTokenIDOut(
                      output_rids,
+                     output_vids,
                      decoded_texts,
-                     surr_output_ids,
-                     read_output_ids,
+                     output_read_ids,
+                     output_read_offsets,
                      output_skip_special_tokens,
                      output_spaces_between_special_tokens,
                      output_meta_info,
@@ -729,87 +729,74 @@ class ModelTpServer:
                  break
 
 
- class ModelTpService(rpyc.Service):
-     exposed_ModelTpServer = ModelTpServer
-
-
- class ModelTpClient:
-     def __init__(
-         self,
-         gpu_ids: List[int],
-         server_args: ServerArgs,
-         model_port_args: ModelPortArgs,
-         model_overide_args,
-     ):
-         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-         self.tp_size = server_args.tp_size
+ def run_tp_server(
+     gpu_id: int,
+     tp_rank: int,
+     server_args: ServerArgs,
+     nccl_port: int,
+     model_overide_args: dict,
+ ):
+     """Run a tensor parallel server."""
+     try:
+         model_server = ModelTpServer(
+             gpu_id,
+             tp_rank,
+             server_args,
+             nccl_port,
+             model_overide_args,
+         )
+         tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+         while True:
+             recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+             model_server.exposed_step(recv_reqs)
+     except Exception:
+         logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+         raise
+
+
+ def launch_tp_servers(
+     gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+ ):
+     """Launch multiple tensor parallel servers."""
+     procs = []
+     for i in tp_rank_range:
+         proc = multiprocessing.Process(
+             target=run_tp_server,
+             args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+         )
+         proc.start()
+         procs.append(proc)
 
-         if self.tp_size * server_args.dp_size == 1:
-             # Init model
-             assert len(gpu_ids) == 1
-             self.model_server = ModelTpService().exposed_ModelTpServer(
-                 0,
-                 gpu_ids[0],
-                 server_args,
-                 model_port_args,
-                 model_overide_args,
-             )
+     return procs
 
-             # Wrap functions
-             def async_wrap(f):
-                 async def _func(*args, **kwargs):
-                     return f(*args, **kwargs)
 
-                 return _func
+ def broadcast_recv_input(data, rank, dist_group):
+     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
-             self.step = async_wrap(self.model_server.exposed_step)
+     if rank == 0:
+         if len(data) == 0:
+             tensor_size = torch.tensor([0], dtype=torch.long)
+             dist.broadcast(tensor_size, src=0, group=dist_group)
          else:
-             with ThreadPoolExecutor(self.tp_size) as executor:
-                 # Launch model processes
-                 if server_args.nnodes == 1:
-                     self.procs = list(
-                         executor.map(
-                             lambda args: start_rpyc_service_process(*args),
-                             [
-                                 (ModelTpService, p)
-                                 for p in model_port_args.model_tp_ports
-                             ],
-                         )
-                     )
-                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                 else:
-                     addrs = [
-                         (ip, port)
-                         for ip, port in zip(
-                             model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                         )
-                     ]
-
-                 self.model_services = list(
-                     executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                 )
-
-                 # Init model
-                 def init_model(i):
-                     return self.model_services[i].ModelTpServer(
-                         gpu_ids[i],
-                         i,
-                         server_args,
-                         model_port_args,
-                         model_overide_args,
-                     )
-
-                 self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-                 # Wrap functions
-                 def async_wrap(func_name):
-                     fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                     async def _func(*args, **kwargs):
-                         tasks = [f(*args, **kwargs) for f in fs]
-                         await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                         return obtain(tasks[0].value)
-
-                     return _func
-
-                 self.step = async_wrap("step")
+             serialized_data = pickle.dumps(data)
+             size = len(serialized_data)
+             tensor_data = torch.ByteTensor(list(serialized_data))
+             tensor_size = torch.tensor([size], dtype=torch.long)
+
+             dist.broadcast(tensor_size, src=0, group=dist_group)
+             dist.broadcast(tensor_data, src=0, group=dist_group)
+     else:
+         tensor_size = torch.tensor([0], dtype=torch.long)
+         dist.broadcast(tensor_size, src=0, group=dist_group)
+         size = tensor_size.item()
+
+         if size == 0:
+             return []
+
+         tensor_data = torch.empty(size, dtype=torch.uint8)
+         dist.broadcast(tensor_data, src=0, group=dist_group)
+
+         serialized_data = bytes(tensor_data.tolist())
+         data = pickle.loads(serialized_data)
+         return data
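
Note: the new broadcast_recv_input above replaces the old rpyc RPC path with a pickle-then-broadcast scheme over torch.distributed: rank 0 first broadcasts the byte length of the pickled object so the other tensor-parallel ranks can allocate a receive buffer, then broadcasts the payload itself. The sketch below reproduces that length-then-payload pattern in a standalone two-process demo; broadcast_pyobj, the gloo backend, and the port number are illustrative choices for the demo, not sglang APIs.

import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(data, rank, group):
    """Ship an arbitrary picklable object from rank 0 to every rank (sketch)."""
    if rank == 0:
        payload = pickle.dumps(data)
        tensor_size = torch.tensor([len(payload)], dtype=torch.long)
        tensor_data = torch.ByteTensor(list(payload))
        dist.broadcast(tensor_size, src=0, group=group)  # 1) announce length
        dist.broadcast(tensor_data, src=0, group=group)  # 2) send payload
        return data
    tensor_size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(tensor_size, src=0, group=group)      # 1) learn length
    tensor_data = torch.empty(tensor_size.item(), dtype=torch.uint8)
    dist.broadcast(tensor_data, src=0, group=group)      # 2) receive payload
    return pickle.loads(bytes(tensor_data.tolist()))


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    obj = {"reqs": ["tokenized request 0", "tokenized request 1"]} if rank == 0 else None
    obj = broadcast_pyobj(obj, rank, dist.group.WORLD)
    print(f"rank {rank} received: {obj}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
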
sglang/srt/managers/detokenizer_manager.py

@@ -1,7 +1,9 @@
  """DetokenizerManager is a process that detokenizes the token ids."""
 
  import asyncio
+ import dataclasses
  import inspect
+ from typing import List
 
  import uvloop
  import zmq
@@ -16,6 +18,15 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+ @dataclasses.dataclass
+ class DecodeStatus:
+     vid: int
+     decoded_text: str
+     decode_ids: List[int]
+     surr_offset: int
+     read_offset: int
+
+
  class DetokenizerManager:
      def __init__(
          self,
@@ -35,19 +46,43 @@ class DetokenizerManager:
              trust_remote_code=server_args.trust_remote_code,
          )
 
+         self.decode_status = {}
+
      async def handle_loop(self):
          while True:
              recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
              assert isinstance(recv_obj, BatchTokenIDOut)
+             bs = len(recv_obj.rids)
+
+             # Initialize decode status
+             read_ids, surr_ids = [], []
+             for i in range(bs):
+                 rid = recv_obj.rids[i]
+                 vid = recv_obj.vids[i]
+                 if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+                     s = DecodeStatus(
+                         vid=vid,
+                         decoded_text=recv_obj.decoded_texts[i],
+                         decode_ids=recv_obj.decode_ids[i],
+                         surr_offset=0,
+                         read_offset=recv_obj.read_offsets[i],
+                     )
+                     self.decode_status[rid] = s
+                 else:
+                     s = self.decode_status[rid]
+                     s.decode_ids = recv_obj.decode_ids[i]
+
+                 read_ids.append(s.decode_ids[s.surr_offset :])
+                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
              # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
              surr_texts = self.tokenizer.batch_decode(
-                 recv_obj.surr_output_ids,
+                 surr_ids,
                  skip_special_tokens=recv_obj.skip_special_tokens[0],
                  spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
              )
              read_texts = self.tokenizer.batch_decode(
-                 recv_obj.read_output_ids,
+                 read_ids,
                  skip_special_tokens=recv_obj.skip_special_tokens[0],
                  spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
              )
@@ -55,11 +90,20 @@ class DetokenizerManager:
              # Trim stop str
              # TODO(lmzheng): handle the case where multiple stop strs are hit
              output_strs = []
-             for i in range(len(recv_obj.rids)):
+             for i in range(bs):
+                 s = self.decode_status[recv_obj.rids[i]]
                  new_text = read_texts[i][len(surr_texts[i]) :]
                  if recv_obj.finished_reason[i] is None:
-                     new_text = find_printable_text(new_text)
-                 output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                     # Streaming chunk: update the decode status
+                     if len(new_text) > 0 and not new_text.endswith("�"):
+                         s.decoded_text = s.decoded_text + new_text
+                         s.surr_offset = s.read_offset
+                         s.read_offset = len(s.decode_ids)
+                         new_text = ""
+                     else:
+                         new_text = find_printable_text(new_text)
+
+                 output_strs.append(s.decoded_text + new_text)
 
                  if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                      pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
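
Note: the DecodeStatus bookkeeping above (decode_ids, surr_offset, read_offset) exists so that streamed output never emits a half-decoded multi-byte character: a chunk is only committed once it no longer ends in the U+FFFD replacement character, otherwise the offsets stay put and the text is held back. The standalone sketch below mimics that logic with a toy byte-level "tokenizer" so it runs without any model; StreamState, step, and toy_decode are illustrative names, not sglang APIs.

from dataclasses import dataclass, field
from typing import List


def toy_decode(ids: List[int]) -> str:
    # Stand-in for tokenizer.batch_decode: one token id == one UTF-8 byte,
    # so multi-byte characters can be split across steps.
    return bytes(ids).decode("utf-8", errors="replace")


@dataclass
class StreamState:
    decode_ids: List[int] = field(default_factory=list)
    decoded_text: str = ""
    surr_offset: int = 0
    read_offset: int = 0


def step(state: StreamState, new_ids: List[int]) -> str:
    """Return the printable chunk for this step, mirroring the DecodeStatus update."""
    state.decode_ids.extend(new_ids)
    surr_text = toy_decode(state.decode_ids[state.surr_offset : state.read_offset])
    read_text = toy_decode(state.decode_ids[state.surr_offset :])
    new_text = read_text[len(surr_text) :]
    if new_text and not new_text.endswith("\ufffd"):
        # Only complete characters: commit them and slide both offsets forward.
        state.decoded_text += new_text
        state.surr_offset = state.read_offset
        state.read_offset = len(state.decode_ids)
        return new_text
    return ""  # hold the partial character back until it is completed


if __name__ == "__main__":
    state = StreamState()
    # "héllo" fed one byte at a time; 0xC3 0xA9 together encode "é".
    for tok in [0x68, 0xC3, 0xA9, 0x6C, 0x6C, 0x6F]:
        print(repr(step(state, [tok])))
    print(repr(state.decoded_text))  # 'héllo'
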
sglang/srt/managers/io_struct.py

@@ -13,25 +13,26 @@ from sglang.srt.sampling_params import SamplingParams
 
  @dataclass
  class GenerateReqInput:
-     # The input prompt
+     # The input prompt. It can be a single prompt or a batch of prompts.
      text: Optional[Union[List[str], str]] = None
-     # The token ids for text; one can either specify text or input_ids
+     # The token ids for text; one can either specify text or input_ids.
      input_ids: Optional[Union[List[List[int]], List[int]]] = None
-     # The image input
+     # The image input. It can be a file name, a url, or base64 encoded string.
+     # See also python/sglang/srt/utils.py:load_image.
      image_data: Optional[Union[List[str], str]] = None
-     # The sampling_params
+     # The sampling_params.
      sampling_params: Union[List[Dict], Dict] = None
-     # The request id
+     # The request id.
      rid: Optional[Union[List[str], str]] = None
-     # Whether to return logprobs
+     # Whether to return logprobs.
      return_logprob: Optional[Union[List[bool], bool]] = None
-     # The start location of the prompt for return_logprob
+     # The start location of the prompt for return_logprob.
      logprob_start_len: Optional[Union[List[int], int]] = None
-     # The number of top logprobs to return
+     # The number of top logprobs to return.
      top_logprobs_num: Optional[Union[List[int], int]] = None
-     # Whether to detokenize tokens in logprobs
+     # Whether to detokenize tokens in logprobs.
      return_text_in_logprobs: bool = False
-     # Whether to stream output
+     # Whether to stream output.
      stream: bool = False
 
      def post_init(self):
@@ -39,11 +40,13 @@ class GenerateReqInput:
              self.text is not None and self.input_ids is not None
          ):
              raise ValueError("Either text or input_ids should be provided.")
-
-         if self.text is not None:
-             is_single = isinstance(self.text, str)
+         if self.sampling_params.get("n", 1) != 1:
+             is_single = False
          else:
-             is_single = isinstance(self.input_ids[0], int)
+             if self.text is not None:
+                 is_single = isinstance(self.text, str)
+             else:
+                 is_single = isinstance(self.input_ids[0], int)
          self.is_single = is_single
 
          if is_single:
@@ -58,7 +61,22 @@ class GenerateReqInput:
              if self.top_logprobs_num is None:
                  self.top_logprobs_num = 0
          else:
-             num = len(self.text) if self.text is not None else len(self.input_ids)
+
+             parallel_sample_num = self.sampling_params.get("n", 1)
+
+             if parallel_sample_num != 1:
+                 # parallel sampling +1 represents the original prefill stage
+                 num = parallel_sample_num + 1
+                 if isinstance(self.text, List):
+                     ## suppot batch operation
+                     self.batch_size = len(self.text)
+                     num = num * len(self.text)
+                 else:
+                     self.batch_size = 1
+             else:
+                 ## support select operation
+                 num = len(self.text) if self.text is not None else len(self.input_ids)
+                 self.batch_size = num
 
              if self.image_data is None:
                  self.image_data = [None] * num
@@ -110,9 +128,10 @@ class TokenizedGenerateReqInput:
  @dataclass
  class BatchTokenIDOut:
      rids: List[str]
+     vids: List[int]
      decoded_texts: List[str]
-     surr_output_ids: List[List[int]]
-     read_output_ids: List[List[int]]
+     decode_ids: List[int]
+     read_offsets: List[int]
      skip_special_tokens: List[bool]
      spaces_between_special_tokens: List[bool]
      meta_info: List[Dict]
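
Note: the GenerateReqInput.post_init changes above fold parallel sampling into the request fan-out. When sampling_params carries n > 1, every prompt reserves n + 1 request slots (one shared prefill pass plus one per sampled continuation, per the code comment), and a batch of prompts multiplies that again; otherwise it is one slot per prompt as before. A small sketch of that arithmetic (expected_request_slots is an illustrative helper, not part of sglang):

from typing import List, Optional, Union


def expected_request_slots(
    text: Optional[Union[List[str], str]],
    input_ids: Optional[Union[List[List[int]], List[int]]],
    sampling_params: dict,
) -> int:
    """How many per-request entries post_init sizes its lists for (sketch)."""
    n = sampling_params.get("n", 1)
    if n != 1:
        # Parallel sampling: n + 1 slots per prompt (prefill + n samples).
        batch_size = len(text) if isinstance(text, list) else 1
        return (n + 1) * batch_size
    # No parallel sampling: one slot per prompt in the batch.
    return len(text) if text is not None else len(input_ids)


if __name__ == "__main__":
    print(expected_request_slots(["a", "b"], None, {"n": 3}))  # (3 + 1) * 2 == 8
    print(expected_request_slots(["a", "b"], None, {}))        # 2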