sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +2 -4
  4. sglang/bench_serving.py +75 -26
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +2 -2
  7. sglang/srt/configs/model_config.py +13 -14
  8. sglang/srt/constrained/__init__.py +13 -14
  9. sglang/srt/constrained/base_grammar_backend.py +13 -15
  10. sglang/srt/constrained/outlines_backend.py +13 -15
  11. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  12. sglang/srt/constrained/xgrammar_backend.py +38 -57
  13. sglang/srt/conversation.py +13 -15
  14. sglang/srt/hf_transformers_utils.py +13 -15
  15. sglang/srt/layers/activation.py +13 -13
  16. sglang/srt/layers/attention/flashinfer_backend.py +13 -6
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  18. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  19. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  20. sglang/srt/layers/custom_op_util.py +13 -14
  21. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  22. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  23. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  24. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  25. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  26. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  27. sglang/srt/layers/layernorm.py +13 -15
  28. sglang/srt/layers/logits_processor.py +13 -15
  29. sglang/srt/layers/quantization/__init__.py +77 -17
  30. sglang/srt/layers/radix_attention.py +13 -15
  31. sglang/srt/layers/rotary_embedding.py +13 -13
  32. sglang/srt/lora/lora.py +13 -14
  33. sglang/srt/lora/lora_config.py +13 -14
  34. sglang/srt/lora/lora_manager.py +22 -24
  35. sglang/srt/managers/data_parallel_controller.py +25 -19
  36. sglang/srt/managers/detokenizer_manager.py +13 -16
  37. sglang/srt/managers/io_struct.py +43 -28
  38. sglang/srt/managers/schedule_batch.py +55 -26
  39. sglang/srt/managers/schedule_policy.py +13 -15
  40. sglang/srt/managers/scheduler.py +89 -70
  41. sglang/srt/managers/session_controller.py +14 -15
  42. sglang/srt/managers/tokenizer_manager.py +29 -22
  43. sglang/srt/managers/tp_worker.py +13 -15
  44. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  45. sglang/srt/metrics/collector.py +13 -15
  46. sglang/srt/metrics/func_timer.py +13 -15
  47. sglang/srt/mm_utils.py +13 -14
  48. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  49. sglang/srt/model_executor/forward_batch_info.py +19 -17
  50. sglang/srt/model_executor/model_runner.py +42 -30
  51. sglang/srt/models/chatglm.py +15 -16
  52. sglang/srt/models/commandr.py +15 -16
  53. sglang/srt/models/dbrx.py +15 -16
  54. sglang/srt/models/deepseek.py +15 -15
  55. sglang/srt/models/deepseek_v2.py +15 -15
  56. sglang/srt/models/exaone.py +14 -15
  57. sglang/srt/models/gemma.py +14 -14
  58. sglang/srt/models/gemma2.py +24 -19
  59. sglang/srt/models/gemma2_reward.py +13 -14
  60. sglang/srt/models/gpt_bigcode.py +14 -14
  61. sglang/srt/models/grok.py +15 -15
  62. sglang/srt/models/internlm2.py +13 -15
  63. sglang/srt/models/internlm2_reward.py +13 -14
  64. sglang/srt/models/llama.py +21 -21
  65. sglang/srt/models/llama_classification.py +13 -14
  66. sglang/srt/models/llama_reward.py +13 -14
  67. sglang/srt/models/llava.py +13 -15
  68. sglang/srt/models/llavavid.py +13 -15
  69. sglang/srt/models/minicpm.py +13 -15
  70. sglang/srt/models/minicpm3.py +13 -15
  71. sglang/srt/models/mistral.py +13 -15
  72. sglang/srt/models/mixtral.py +15 -15
  73. sglang/srt/models/mixtral_quant.py +14 -14
  74. sglang/srt/models/olmo.py +21 -19
  75. sglang/srt/models/olmoe.py +23 -20
  76. sglang/srt/models/qwen.py +14 -14
  77. sglang/srt/models/qwen2.py +22 -19
  78. sglang/srt/models/qwen2_moe.py +17 -18
  79. sglang/srt/models/stablelm.py +18 -16
  80. sglang/srt/models/torch_native_llama.py +15 -17
  81. sglang/srt/models/xverse.py +13 -14
  82. sglang/srt/models/xverse_moe.py +15 -16
  83. sglang/srt/models/yivl.py +13 -15
  84. sglang/srt/openai_api/adapter.py +13 -15
  85. sglang/srt/openai_api/protocol.py +13 -15
  86. sglang/srt/sampling/sampling_batch_info.py +4 -1
  87. sglang/srt/sampling/sampling_params.py +13 -15
  88. sglang/srt/server.py +59 -34
  89. sglang/srt/server_args.py +22 -22
  90. sglang/srt/utils.py +196 -17
  91. sglang/test/few_shot_gsm8k.py +8 -4
  92. sglang/test/runners.py +13 -14
  93. sglang/test/test_utils.py +1 -1
  94. sglang/version.py +1 -1
  95. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  96. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
  97. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  98. sglang/srt/layers/fused_moe/__init__.py +0 -1
  99. sglang-0.3.6.dist-info/RECORD +0 -161
  100. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about requests and batches.
 
@@ -180,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
@@ -193,6 +192,7 @@
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
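
Note: the new `input_embeds` argument lets a caller bypass the token-embedding lookup by supplying one vector per prompt token. A minimal sketch of the shape contract (hypothetical values; real hidden sizes are e.g. 4096):

# Illustrative only: input_embeds is a nested list, one vector per token.
hidden_size = 4  # hypothetical; chosen small for readability
prompt_embeds = [
    [0.10, 0.20, 0.30, 0.40],  # embedding for token 0
    [0.50, 0.60, 0.70, 0.80],  # embedding for token 1
]
assert all(len(vec) == hidden_size for vec in prompt_embeds)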
@@ -439,14 +439,18 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None
 
-    # For utility
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -469,6 +473,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -489,10 +494,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -500,6 +506,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -613,7 +620,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -627,6 +634,9 @@
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
+        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -645,6 +655,11 @@
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
 
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -667,6 +682,12 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
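
Note: because `prepare_for_extend` flattens each request's per-token vectors with `extend`, the batch tensor ends up with one row per extend token across all requests. A self-contained sketch of that batching step (hypothetical sizes, mirroring the logic above rather than the real classes):

import torch

# Two hypothetical requests with 2 and 3 prompt tokens, hidden size 4.
req_a = [[0.1] * 4 for _ in range(2)]
req_b = [[0.2] * 4 for _ in range(3)]

input_embeds = []
for req_embeds in [req_a, req_b]:
    input_embeds.extend(req_embeds)  # flatten: one entry per token, no nesting per request

batch_embeds = torch.tensor(input_embeds) if input_embeds else None
print(batch_embeds.shape)  # torch.Size([5, 4]) -> one row per extend token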
@@ -707,7 +728,7 @@
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=enable_overlap_schedule,
+            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -724,16 +745,20 @@
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens += running_bs
+
+        # For overlap scheduler, the output_ids has one step delay
+        delta = 0 if self.enable_overlap else -1
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
             [
-                len(r.origin_input_ids) + len(r.output_ids) - 1
+                len(r.origin_input_ids) + len(r.output_ids) + delta
                 for r in running_batch.reqs
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
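
Note: the `delta` term reconciles two bookkeeping conventions. In normal mode `output_ids` already contains the token produced in the current step, whose KV entry is not cached yet, hence `-1`; in overlap mode `output_ids` lags one step behind, so no correction is needed. A worked example under that assumption:

origin_len = 5           # hypothetical prompt length
output_len_normal = 3    # normal mode: includes the just-produced token
output_len_overlap = 2   # overlap mode: one-step delay, newest token absent

# Both expressions count the same number of cached tokens:
assert origin_len + output_len_normal - 1 == origin_len + output_len_overlap + 0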
@@ -897,7 +922,7 @@
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
-    def prepare_for_decode(self, enable_overlap: bool = False):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -914,7 +939,7 @@
         else:
             locs = self.seq_lens
 
-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
@@ -1045,6 +1070,7 @@
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1115,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
sglang/srt/managers/schedule_policy.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Request scheduler policy"""
 
 import os
sglang/srt/managers/scheduler.py

@@ -1,21 +1,18 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -30,7 +27,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -75,6 +72,7 @@ from sglang.srt.utils import (
     configure_logger,
     crash_on_warnings,
     get_zmq_socket,
+    gpu_proc_affinity,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -84,7 +82,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
 
 
 class Scheduler:
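
Note: adding `.lower()` makes the flag case-insensitive, so spellings like `SGLANG_TEST_RETRACT=True` now enable the test path as well. For illustration:

import os

os.environ["SGLANG_TEST_RETRACT"] = "True"
old = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"          # False
new = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"  # True
print(old, new)  # False True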
@@ -304,6 +302,9 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
+        # Tells whether the current running batch is full so that we can skip
+        # the check of whether to prefill new requests.
+        # This is an optimization to reduce the overhead of the prefill check.
         self.batch_is_full = False
 
         # Init watchdog thread
@@ -466,6 +467,7 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -524,14 +526,23 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
+
             if recv_req.session_id is not None:
                 req.finished_reason = FINISH_ABORT(
                     f"Invalid request: session id {recv_req.session_id} does not exist"
@@ -539,11 +550,9 @@
                 self.waiting_queue.append(req)
                 return
         else:
-            # Handle sessions
+            # Create a new request from a previsou session
             session = self.sessions[recv_req.session_id]
-            req, new_session_id = session.create_req(recv_req, self.tokenizer)
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return
@@ -723,40 +732,30 @@
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        # Prefill first
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-        before_bs = self.running_batch.batch_size()
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-        if before_bs != self.running_batch.batch_size():
-            self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
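
Note: after this refactor `get_next_batch_to_run` reads as a short pipeline: merge the last extend batch, prefer a new prefill batch, otherwise decode. A simplified stand-in showing just the new control flow (not the real API; stubs passed in for readability):

from typing import Callable, Optional

def next_batch(get_new_prefill: Callable, running, update_running: Callable):
    """Simplified shape of the new get_next_batch_to_run (sketch, not the real class)."""
    new_batch = get_new_prefill()
    if new_batch is not None:       # 1. prefill takes priority
        return new_batch
    if running is None:             # 2. nothing to decode
        return None
    return update_running(running)  # 3. may return None after filtering/retracting

# Tiny demo with stand-ins:
print(next_batch(lambda: None, ["req"], lambda b: b))  # ['req']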
@@ -852,14 +851,20 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(self.enable_overlap)
+        new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
-        if self.is_mixed_chunk and self.running_batch is not None:
+        if (
+            self.is_mixed_chunk
+            and self.running_batch is not None
+            and not (new_batch.return_logprob or self.running_batch.return_logprob)
+        ):
+            # TODO (lianmin): support return_logprob + mixed chunked prefill
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(self.enable_overlap)
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -868,15 +873,16 @@
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-        batch = self.running_batch
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.running_batch = None
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -902,11 +908,15 @@
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.running_batch = None
-                return
+                self.batch_is_full = False
+                return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(self.enable_overlap)
+        batch.prepare_for_decode()
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -981,8 +991,13 @@
             if req.is_retracted:
                 continue
 
+            if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                # Free the one delayed token for the mixed decode batch
+                j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                continue
+
             if req.is_being_chunked <= 0:
-                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_id)
                 req.check_finished()
@@ -992,14 +1007,15 @@
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.grammar is not None:
-                    req.grammar.accept_token(next_token_id)
-
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
             else:
+                # Inflight reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
             if batch.next_batch_sampling_info:
@@ -1017,18 +1033,18 @@
                 continue
 
             req.embedding = embeddings[i]
-            if req.is_being_chunked > 0:
-                req.is_being_chunked -= 1
-            else:
-                # Inflight reqs' prefill is not finished
-                # dummy output token for embedding models
+            if req.is_being_chunked <= 0:
+                # Dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
 
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
             else:
-                self.tree_cache.cache_unfinished_req(req)
+                # Inflight reqs' prefill is not finished
+                req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
 
@@ -1056,6 +1072,7 @@
                 continue
 
             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -1063,9 +1080,6 @@
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.grammar is not None:
-                req.grammar.accept_token(next_token_id)
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -1076,6 +1090,9 @@
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
+
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
                 torch.cuda.current_stream().synchronize()
@@ -1179,7 +1196,6 @@
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []
 
@@ -1207,7 +1223,6 @@
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -1258,7 +1273,6 @@
                         output_meta_info,
                         output_finished_reason,
                         output_no_stop_trim,
-                        output_session_ids,
                     )
                 )
         else:  # embedding or reward model
@@ -1389,9 +1403,12 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # set cpu affinity to this gpu process
+    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None:
-        dp_rank = int(os.getenv("DP_RANK", -1))
+    if dp_rank is None and "DP_RANK" in os.environ:
+        dp_rank = int(os.environ["DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
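
Note: the old fallback `int(os.getenv("DP_RANK", -1))` could never leave `dp_rank` as `None`, so the `if dp_rank is None` logging branch below was unreachable whenever the router path ran, and a missing variable silently became rank -1. The new guard only overrides `dp_rank` when the variable is actually set. For illustration:

def resolve_dp_rank(dp_rank, env):
    # New behavior: consult DP_RANK only when it is actually set.
    if dp_rank is None and "DP_RANK" in env:
        dp_rank = int(env["DP_RANK"])
    return dp_rank

assert resolve_dp_rank(None, {}) is None            # was -1 before this fix
assert resolve_dp_rank(None, {"DP_RANK": "2"}) == 2
assert resolve_dp_rank(3, {"DP_RANK": "2"}) == 3    # an explicit value wins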
@@ -1402,7 +1419,9 @@
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
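
Note: the readiness handshake now carries `max_total_num_tokens`, so the launching process learns the KV-cache capacity without a follow-up query. A runnable sketch of the receiving side (stand-in function and made-up capacity; the real reader lives in sglang/srt/server.py):

import multiprocessing as mp

def fake_scheduler(pipe_writer):
    # Stand-in for run_scheduler_process: report readiness plus KV capacity.
    pipe_writer.send({"status": "ready", "max_total_num_tokens": 12345})

if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=fake_scheduler, args=(writer,))
    proc.start()
    data = reader.recv()                 # blocks until the child reports in
    assert data["status"] == "ready"
    print(data["max_total_num_tokens"])  # 12345 (made-up capacity)
    proc.join()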
sglang/srt/managers/session_controller.py

@@ -1,15 +1,14 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import copy
 import uuid
@@ -27,13 +26,13 @@ class Session:
         self.reqs: List[Req] = []
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        # renew session id
-        self.session_id = uuid.uuid4().hex
         if req.session_rid is not None:
             while len(self.reqs) > 0:
                 if self.reqs[-1].rid == req.session_rid:
                     break
                 self.reqs = self.reqs[:-1]
+        else:
+            self.reqs = []
         if len(self.reqs) > 0:
             input_ids = (
                 self.reqs[-1].origin_input_ids
@@ -59,4 +58,4 @@ class Session:
             )
         else:
             self.reqs.append(new_req)
-        return new_req, self.session_id
+        return new_req
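
Note: `create_req` no longer mints a fresh `session_id` per turn, and the scheduler hunk above correspondingly stops re-keying `self.sessions`; a session now keeps one stable id for its whole lifetime, and `session_rid` (when given) selects which earlier request to branch from. Sketch of the invariant (toy class, not the real `Session`):

import uuid

class TinySession:
    def __init__(self):
        self.session_id = uuid.uuid4().hex  # minted once, never renewed

s = TinySession()
first_id = s.session_id
# ... any number of create_req() calls later ...
assert s.session_id == first_id  # clients can keep reusing the same id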