sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0

sglang/srt/managers/scheduler.py

@@ -15,6 +15,7 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import dataclasses
 import logging
 import os
 import threading
@@ -29,16 +30,19 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -58,15 +62,18 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +83,6 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
 
@@ -103,14 +106,17 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
+        # Session info
+        self.sessions = {}
+
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
@@ -160,6 +166,14 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -223,8 +237,12 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -337,46 +355,34 @@ class Scheduler:
 
         kill_parent_process()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
-        """A normal blocking scheduler loop."""
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -387,17 +393,85 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
             while True:
@@ -409,7 +483,7 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
@@ -433,6 +507,11 @@ class Scheduler:
                     self.start_profile()
                 else:
                     self.stop_profile()
+            elif isinstance(recv_req, OpenSessionReqInput):
+                session_id = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+            elif isinstance(recv_req, CloseSessionReqInput):
+                self.close_session(recv_req)
             elif isinstance(recv_req, GetMemPoolSizeReq):
                 self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -444,14 +523,30 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(
-            recv_req.rid,
-            recv_req.input_text,
-            recv_req.input_ids,
-            recv_req.sampling_params,
-            lora_path=recv_req.lora_path,
-        )
-        req.tokenizer = self.tokenizer
+        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                recv_req.sampling_params,
+                lora_path=recv_req.lora_path,
+            )
+            req.tokenizer = self.tokenizer
+            if recv_req.session_id is not None:
+                req.finished_reason = FINISH_ABORT(
+                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                )
+                self.waiting_queue.append(req)
+                return
+        else:
+            # Handle sessions
+            session = self.sessions[recv_req.session_id]
+            req, new_session_id = session.create_req(recv_req, self.tokenizer)
+            del self.sessions[recv_req.session_id]
+            self.sessions[new_session_id] = session
+            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.waiting_queue.append(req)
+                return
 
         # Image inputs
         if recv_req.image_inputs is not None:
@@ -462,6 +557,15 @@ class Scheduler:
                 req.origin_input_ids_unpadded, req.image_inputs
             )
 
+            if len(req.origin_input_ids) > self.max_req_input_len:
+                req.finished_reason = FINISH_ABORT(
+                    "Image request length is longer than the KV cache pool size or "
+                    "the max context length aborting because you cannot truncate the image embeds"
+                )
+                req.sampling_params.max_new_tokens = 0
+                self.waiting_queue.append(req)
+                return
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -599,21 +703,23 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                 "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
             )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                 "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
             )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
@@ -747,7 +853,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
812
918
  logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
813
919
  model_worker_batch
814
920
  )
921
+ elif batch.forward_mode.is_idle():
922
+ model_worker_batch = batch.get_model_worker_batch()
923
+ self.tp_worker.forward_batch_idle(model_worker_batch)
924
+ return
815
925
  else:
816
926
  logits_output = None
817
927
  if self.skip_tokenizer_init:
@@ -834,8 +944,12 @@ class Scheduler:
834
944
  self.process_batch_result_decode(batch, result)
835
945
  if batch.is_empty():
836
946
  self.running_batch = None
837
- else:
947
+ elif batch.forward_mode.is_extend():
838
948
  self.process_batch_result_prefill(batch, result)
949
+ elif batch.forward_mode.is_dummy_first():
950
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
951
+ torch.cuda.current_stream().synchronize()
952
+ batch.next_batch_sampling_info.sampling_info_done.set()
839
953
 
840
954
  def process_batch_result_prefill(self, batch: ScheduleBatch, result):
841
955
 
@@ -843,7 +957,7 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
863
977
 
864
978
  # Check finish conditions
865
979
  logprob_pt = 0
866
- for i, req in enumerate(batch.reqs):
980
+ for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
867
981
  if req.is_retracted:
868
982
  continue
869
983
 
870
984
  if req.is_being_chunked <= 0:
871
985
  # Inflight reqs' prefill is not finished
872
986
  req.completion_tokens_wo_jump_forward += 1
873
- req.output_ids.append(next_token_ids[i])
987
+ req.output_ids.append(next_token_id)
874
988
  req.check_finished()
875
989
 
876
990
  if req.finished():
@@ -879,7 +993,7 @@ class Scheduler:
879
993
  self.tree_cache.cache_unfinished_req(req)
880
994
 
881
995
  if req.grammar is not None:
882
- req.grammar.accept_token(next_token_ids[i])
996
+ req.grammar.accept_token(next_token_id)
883
997
 
884
998
  if req.return_logprob:
885
999
  logprob_pt += self.add_logprob_return_values(
@@ -888,6 +1002,11 @@ class Scheduler:
888
1002
  else:
889
1003
  req.is_being_chunked -= 1
890
1004
 
1005
+ if batch.next_batch_sampling_info:
1006
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
1007
+ torch.cuda.current_stream().synchronize()
1008
+ batch.next_batch_sampling_info.sampling_info_done.set()
1009
+
891
1010
  else: # embedding or reward model
892
1011
  embeddings, bid = result
893
1012
  embeddings = embeddings.tolist()
@@ -918,7 +1037,7 @@ class Scheduler:
918
1037
  self.num_generated_tokens += len(batch.reqs)
919
1038
 
920
1039
  if self.enable_overlap:
921
- logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
1040
+ logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
922
1041
  next_token_logprobs = logits_output.next_token_logprobs
923
1042
  else:
924
1043
  # Move next_token_ids and logprobs to cpu
@@ -936,7 +1055,7 @@ class Scheduler:
936
1055
  if req.is_retracted:
937
1056
  continue
938
1057
 
939
- if self.server_args.enable_overlap_schedule and (req.finished()):
1058
+ if self.enable_overlap and req.finished():
940
1059
  self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
941
1060
  continue
942
1061
 
@@ -957,6 +1076,11 @@ class Scheduler:
957
1076
  if req.top_logprobs_num > 0:
958
1077
  req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
959
1078
 
1079
+ if batch.next_batch_sampling_info:
1080
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
1081
+ torch.cuda.current_stream().synchronize()
1082
+ batch.next_batch_sampling_info.sampling_info_done.set()
1083
+
960
1084
  self.stream_output(batch.reqs)
961
1085
 
962
1086
  self.token_to_kv_pool.free_group_end()
@@ -1055,6 +1179,7 @@ class Scheduler:
1055
1179
  output_skip_special_tokens = []
1056
1180
  output_spaces_between_special_tokens = []
1057
1181
  output_no_stop_trim = []
1182
+ output_session_ids = []
1058
1183
  else: # embedding or reward model
1059
1184
  output_embeddings = []
1060
1185
 
@@ -1082,6 +1207,7 @@ class Scheduler:
1082
1207
  req.sampling_params.spaces_between_special_tokens
1083
1208
  )
1084
1209
  output_no_stop_trim.append(req.sampling_params.no_stop_trim)
1210
+ output_session_ids.append(req.session_id)
1085
1211
 
1086
1212
  meta_info = {
1087
1213
  "prompt_tokens": len(req.origin_input_ids),
@@ -1132,6 +1258,7 @@ class Scheduler:
1132
1258
  output_meta_info,
1133
1259
  output_finished_reason,
1134
1260
  output_no_stop_trim,
1261
+ output_session_ids,
1135
1262
  )
1136
1263
  )
1137
1264
  else: # embedding or reward model
@@ -1234,6 +1361,25 @@ class Scheduler:
1234
1361
  )
1235
1362
  logger.info("Profiler is done")
1236
1363
 
1364
+ def open_session(self, recv_req: OpenSessionReqInput) -> str:
1365
+ # handle error
1366
+ session_id = recv_req.session_id
1367
+ if session_id in self.sessions:
1368
+ logger.warning(f"session id {session_id} already exist, cannot open.")
1369
+ else:
1370
+ self.sessions[session_id] = Session(
1371
+ recv_req.capacity_of_str_len, session_id
1372
+ )
1373
+ return session_id
1374
+
1375
+ def close_session(self, recv_req: CloseSessionReqInput):
1376
+ # handle error
1377
+ session_id = recv_req.session_id
1378
+ if session_id not in self.sessions:
1379
+ logger.warning(f"session id {session_id} does not exist, cannot delete.")
1380
+ else:
1381
+ del self.sessions[session_id]
1382
+
1237
1383
 
1238
1384
  def run_scheduler_process(
1239
1385
  server_args: ServerArgs,
@@ -1243,6 +1389,10 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None:
+        dp_rank = int(os.getenv("DP_RANK", -1))
+
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
@@ -1253,7 +1403,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
-        if server_args.enable_overlap_schedule:
+        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
         else:
             scheduler.event_loop_normal()

sglang/srt/managers/session_controller.py (new file)

@@ -0,0 +1,62 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import copy
+import uuid
+from dataclasses import dataclass
+from typing import Optional
+
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+
+
+class Session:
+    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
+        self.capacity_of_str_len = capacity_of_str_len
+        self.reqs: List[Req] = []
+
+    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
+        # renew session id
+        self.session_id = uuid.uuid4().hex
+        if req.session_rid is not None:
+            while len(self.reqs) > 0:
+                if self.reqs[-1].rid == req.session_rid:
+                    break
+                self.reqs = self.reqs[:-1]
+        if len(self.reqs) > 0:
+            input_ids = (
+                self.reqs[-1].origin_input_ids
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
+        else:
+            input_ids = req.input_ids
+        new_req = Req(
+            req.rid,
+            None,
+            input_ids,
+            req.sampling_params,
+            lora_path=req.lora_path,
+            session_id=self.session_id,
+        )
+        new_req.tokenizer = tokenizer
+        if req.session_rid is not None and len(self.reqs) == 0:
+            new_req.finished_reason = FINISH_ABORT(
+                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
+            )
+        else:
+            self.reqs.append(new_req)
+        return new_req, self.session_id
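
The new Session controller above chains a follow-up request onto an earlier request in the same session: create_req trims the stored request history back to the entry whose rid matches session_rid, then builds the new prompt from the previous request's input ids, its output ids (capped at max_new_tokens), and the new input ids. A minimal standalone sketch of that chaining rule, using plain lists and dicts instead of Req objects (the helper name below is illustrative, not part of sglang):

    # Illustration only: mirrors the id-chaining logic of Session.create_req.
    from typing import Dict, List, Optional

    def chain_input_ids(
        history: List[Dict],  # each entry: {"rid", "input_ids", "output_ids", "max_new_tokens"}
        new_input_ids: List[int],
        session_rid: Optional[str],
    ) -> List[int]:
        # Drop trailing history entries until the request named by session_rid is last.
        if session_rid is not None:
            while history and history[-1]["rid"] != session_rid:
                history = history[:-1]
        if not history:
            return new_input_ids
        prev = history[-1]
        # Previous prompt + previous generation (capped) + the new prompt.
        return (
            prev["input_ids"]
            + prev["output_ids"][: prev["max_new_tokens"]]
            + new_input_ids
        )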

sglang/srt/managers/tokenizer_manager.py

@@ -23,6 +23,7 @@ import os
 import signal
 import sys
 import time
+import uuid
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # For session info
+        self.session_futures = {}  # session_id -> asyncio event
+
         # Others
         self.gracefully_exit = False
 
@@ -211,6 +218,8 @@ class TokenizerManager:
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
+            session_id = obj.session_id
+            session_rid = obj.session_rid
 
             if len(input_ids) >= self.context_len:
                 raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
                 obj.lora_path,
+                session_id=session_id,
+                session_rid=session_rid,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def open_session(
+        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        session_id = uuid.uuid4().hex
+        obj.session_id = session_id
+        self.send_to_scheduler.send_pyobj(obj)
+        self.session_futures[session_id] = asyncio.Future()
+        session_id = await self.session_futures[session_id]
+        del self.session_futures[session_id]
+        return session_id
+
+    async def close_session(
+        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        assert not self.to_create_loop, "close session should not be the first request"
+        await self.send_to_scheduler.send_pyobj(obj)
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
                     if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                         self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                     continue
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+                continue
 
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
                     out_dict = {
                         "text": recv_obj.output_strs[i],
                         "meta_info": recv_obj.meta_info[i],
+                        "session_id": recv_obj.session_ids[i],
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
                     out_dict = {
                         "token_ids": recv_obj.output_ids[i],
                         "meta_info": recv_obj.meta_info[i],
+                        "session_id": recv_obj.session_ids[i],
                     }
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
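
Taken together, the scheduler and TokenizerManager changes in this release introduce a session lifecycle: open_session asks the scheduler to create a Session and returns its id, subsequent generate requests can carry session_id and session_rid to continue from an earlier request, and close_session discards the stored history. A rough usage sketch against the new TokenizerManager methods; the keyword-argument construction of OpenSessionReqInput and CloseSessionReqInput is an assumption based on the fields read in this diff, so check io_struct.py in 0.3.6 for the exact definitions:

    # Sketch only: request-object constructors are assumed from the fields used above.
    from sglang.srt.managers.io_struct import CloseSessionReqInput, OpenSessionReqInput

    async def demo_session(tokenizer_manager):
        # Ask the scheduler to create a session; the returned id is generated server-side.
        session_id = await tokenizer_manager.open_session(
            OpenSessionReqInput(capacity_of_str_len=1024)
        )

        # ... send GenerateReqInput requests here carrying session_id (and session_rid
        # pointing at the request to continue from) to chain prompts ...

        # Drop the session so the scheduler releases its stored request history.
        await tokenizer_manager.close_session(CloseSessionReqInput(session_id=session_id))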