sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/scheduler.py
+++ b/sglang/srt/managers/scheduler.py
@@ -15,6 +15,7 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import dataclasses
 import logging
 import os
 import threading
@@ -29,16 +30,19 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -58,15 +62,18 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +83,6 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
 
@@ -103,17 +106,23 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
+        # Session info
+        self.sessions = {}
+
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
+            self.send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name
+            )
 
            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
@@ -127,6 +136,7 @@ class Scheduler:
                )
        else:
            self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
        # Init tokenizer
@@ -156,6 +166,14 @@ class Scheduler:
                    trust_remote_code=server_args.trust_remote_code,
                )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
+
        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
@@ -219,8 +237,12 @@ class Scheduler:
 
        # Init running status
        self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
@@ -333,46 +355,34 @@ class Scheduler:
 
        kill_parent_process()
 
-    @torch.inference_mode()
+    @torch.no_grad()
    def event_loop_normal(self):
-        """A normal blocking scheduler loop."""
-        self.last_batch = None
-
+        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
 
            batch = self.get_next_batch_to_run()
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
            self.cur_batch = batch
 
            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
            else:
+                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
 
            self.last_batch = batch
 
-    @torch.inference_mode()
+    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
@@ -383,17 +393,85 @@ class Scheduler:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
+                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
 
            self.last_batch = batch
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
    def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []
 
            while True:
@@ -405,7 +483,7 @@ class Scheduler:
        else:
            recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs
 
@@ -421,7 +499,7 @@ class Scheduler:
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
@@ -429,8 +507,13 @@ class Scheduler:
                    self.start_profile()
                else:
                    self.stop_profile()
+            elif isinstance(recv_req, OpenSessionReqInput):
+                session_id = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+            elif isinstance(recv_req, CloseSessionReqInput):
+                self.close_session(recv_req)
            elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
@@ -440,14 +523,30 @@ class Scheduler:
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
-        req = Req(
-            recv_req.rid,
-            recv_req.input_text,
-            recv_req.input_ids,
-            recv_req.sampling_params,
-            lora_path=recv_req.lora_path,
-        )
-        req.tokenizer = self.tokenizer
+        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                recv_req.sampling_params,
+                lora_path=recv_req.lora_path,
+            )
+            req.tokenizer = self.tokenizer
+            if recv_req.session_id is not None:
+                req.finished_reason = FINISH_ABORT(
+                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                )
+                self.waiting_queue.append(req)
+                return
+        else:
+            # Handle sessions
+            session = self.sessions[recv_req.session_id]
+            req, new_session_id = session.create_req(recv_req, self.tokenizer)
+            del self.sessions[recv_req.session_id]
+            self.sessions[new_session_id] = session
+            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.waiting_queue.append(req)
+                return
 
        # Image inputs
        if recv_req.image_inputs is not None:
@@ -458,6 +557,15 @@ class Scheduler:
                req.origin_input_ids_unpadded, req.image_inputs
            )
 
+            if len(req.origin_input_ids) > self.max_req_input_len:
+                req.finished_reason = FINISH_ABORT(
+                    "Image request length is longer than the KV cache pool size or "
+                    "the max context length aborting because you cannot truncate the image embeds"
+                )
+                req.sampling_params.max_new_tokens = 0
+                self.waiting_queue.append(req)
+                return
+
        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
@@ -595,21 +703,23 @@ class Scheduler:
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
            )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
    def get_next_batch_to_run(self):
        # Merge the prefill batch into the running batch
@@ -743,7 +853,7 @@ class Scheduler:
            self.tree_cache,
            self.model_config,
        )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)
 
        # Mixed-style chunked prefill
        if self.is_mixed_chunk and self.running_batch is not None:
@@ -808,6 +918,10 @@ class Scheduler:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
+            elif batch.forward_mode.is_idle():
+                model_worker_batch = batch.get_model_worker_batch()
+                self.tp_worker.forward_batch_idle(model_worker_batch)
+                return
            else:
                logits_output = None
                if self.skip_tokenizer_init:
@@ -830,8 +944,12 @@ class Scheduler:
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
-        else:
+        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
 
    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
 
@@ -839,7 +957,7 @@ class Scheduler:
            logits_output, next_token_ids, bid = result
 
            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
@@ -859,14 +977,14 @@ class Scheduler:
 
            # Check finish conditions
            logprob_pt = 0
-            for i, req in enumerate(batch.reqs):
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue
 
                if req.is_being_chunked <= 0:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
-                    req.output_ids.append(next_token_ids[i])
+                    req.output_ids.append(next_token_id)
                    req.check_finished()
 
                    if req.finished():
@@ -875,7 +993,7 @@ class Scheduler:
                        self.tree_cache.cache_unfinished_req(req)
 
                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_ids[i])
+                        req.grammar.accept_token(next_token_id)
 
                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
@@ -884,6 +1002,11 @@ class Scheduler:
                else:
                    req.is_being_chunked -= 1
 
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()
@@ -914,7 +1037,7 @@ class Scheduler:
        self.num_generated_tokens += len(batch.reqs)
 
        if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
@@ -932,7 +1055,7 @@ class Scheduler:
            if req.is_retracted:
                continue
 
-            if self.server_args.enable_overlap_schedule and (req.finished()):
+            if self.enable_overlap and req.finished():
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue
 
@@ -953,6 +1076,11 @@ class Scheduler:
            if req.top_logprobs_num > 0:
                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
        self.stream_output(batch.reqs)
 
        self.token_to_kv_pool.free_group_end()
@@ -1051,6 +1179,7 @@ class Scheduler:
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
+            output_session_ids = []
        else:  # embedding or reward model
            output_embeddings = []
 
@@ -1078,6 +1207,7 @@ class Scheduler:
                    req.sampling_params.spaces_between_special_tokens
                )
                output_no_stop_trim.append(req.sampling_params.no_stop_trim)
+                output_session_ids.append(req.session_id)
 
                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
@@ -1128,6 +1258,7 @@ class Scheduler:
                    output_meta_info,
                    output_finished_reason,
                    output_no_stop_trim,
+                    output_session_ids,
                )
            )
        else:  # embedding or reward model
@@ -1230,6 +1361,25 @@ class Scheduler:
            )
            logger.info("Profiler is done")
 
+    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+        # handle error
+        session_id = recv_req.session_id
+        if session_id in self.sessions:
+            logger.warning(f"session id {session_id} already exist, cannot open.")
+        else:
+            self.sessions[session_id] = Session(
+                recv_req.capacity_of_str_len, session_id
+            )
+        return session_id
+
+    def close_session(self, recv_req: CloseSessionReqInput):
+        # handle error
+        session_id = recv_req.session_id
+        if session_id not in self.sessions:
+            logger.warning(f"session id {session_id} does not exist, cannot delete.")
+        else:
+            del self.sessions[session_id]
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
@@ -1239,6 +1389,10 @@ def run_scheduler_process(
    dp_rank: Optional[int],
    pipe_writer,
 ):
+    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None:
+        dp_rank = int(os.getenv("DP_RANK", -1))
+
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
@@ -1249,7 +1403,7 @@ def run_scheduler_process(
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
-        if server_args.enable_overlap_schedule:
+        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
--- /dev/null
+++ b/sglang/srt/managers/session_controller.py
@@ -0,0 +1,62 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import copy
+import uuid
+from dataclasses import dataclass
+from typing import Optional
+
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+
+
+class Session:
+    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
+        self.capacity_of_str_len = capacity_of_str_len
+        self.reqs: List[Req] = []
+
+    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
+        # renew session id
+        self.session_id = uuid.uuid4().hex
+        if req.session_rid is not None:
+            while len(self.reqs) > 0:
+                if self.reqs[-1].rid == req.session_rid:
+                    break
+                self.reqs = self.reqs[:-1]
+        if len(self.reqs) > 0:
+            input_ids = (
+                self.reqs[-1].origin_input_ids
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
+        else:
+            input_ids = req.input_ids
+        new_req = Req(
+            req.rid,
+            None,
+            input_ids,
+            req.sampling_params,
+            lora_path=req.lora_path,
+            session_id=self.session_id,
+        )
+        new_req.tokenizer = tokenizer
+        if req.session_rid is not None and len(self.reqs) == 0:
+            new_req.finished_reason = FINISH_ABORT(
+                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
+            )
+        else:
+            self.reqs.append(new_req)
+        return new_req, self.session_id