sglang 0.3.6__py3-none-any.whl → 0.3.6.post2__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (108)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +4 -7
  4. sglang/bench_one_batch_server.py +2 -2
  5. sglang/bench_serving.py +75 -26
  6. sglang/check_env.py +7 -1
  7. sglang/lang/backend/base_backend.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +2 -2
  9. sglang/lang/tracer.py +1 -1
  10. sglang/launch_server.py +0 -3
  11. sglang/srt/configs/model_config.py +15 -20
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +13 -15
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +38 -57
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +13 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +14 -7
  21. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  22. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  23. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  24. sglang/srt/layers/custom_op_util.py +13 -14
  25. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  26. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  27. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  28. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  29. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  30. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  31. sglang/srt/layers/layernorm.py +13 -15
  32. sglang/srt/layers/logits_processor.py +13 -15
  33. sglang/srt/layers/quantization/__init__.py +77 -17
  34. sglang/srt/layers/radix_attention.py +13 -15
  35. sglang/srt/layers/rotary_embedding.py +13 -13
  36. sglang/srt/layers/sampler.py +1 -1
  37. sglang/srt/lora/lora.py +13 -14
  38. sglang/srt/lora/lora_config.py +13 -14
  39. sglang/srt/lora/lora_manager.py +22 -24
  40. sglang/srt/managers/data_parallel_controller.py +25 -19
  41. sglang/srt/managers/detokenizer_manager.py +13 -18
  42. sglang/srt/managers/image_processor.py +6 -9
  43. sglang/srt/managers/io_struct.py +43 -28
  44. sglang/srt/managers/schedule_batch.py +92 -27
  45. sglang/srt/managers/schedule_policy.py +13 -15
  46. sglang/srt/managers/scheduler.py +94 -72
  47. sglang/srt/managers/session_controller.py +29 -19
  48. sglang/srt/managers/tokenizer_manager.py +29 -22
  49. sglang/srt/managers/tp_worker.py +13 -15
  50. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  51. sglang/srt/metrics/collector.py +13 -15
  52. sglang/srt/metrics/func_timer.py +13 -15
  53. sglang/srt/mm_utils.py +13 -14
  54. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  55. sglang/srt/model_executor/forward_batch_info.py +19 -17
  56. sglang/srt/model_executor/model_runner.py +42 -30
  57. sglang/srt/models/chatglm.py +15 -16
  58. sglang/srt/models/commandr.py +15 -16
  59. sglang/srt/models/dbrx.py +15 -16
  60. sglang/srt/models/deepseek.py +15 -15
  61. sglang/srt/models/deepseek_v2.py +15 -15
  62. sglang/srt/models/exaone.py +14 -15
  63. sglang/srt/models/gemma.py +14 -14
  64. sglang/srt/models/gemma2.py +24 -19
  65. sglang/srt/models/gemma2_reward.py +13 -14
  66. sglang/srt/models/gpt_bigcode.py +14 -14
  67. sglang/srt/models/grok.py +15 -15
  68. sglang/srt/models/internlm2.py +13 -15
  69. sglang/srt/models/internlm2_reward.py +13 -14
  70. sglang/srt/models/llama.py +21 -21
  71. sglang/srt/models/llama_classification.py +13 -14
  72. sglang/srt/models/llama_reward.py +13 -14
  73. sglang/srt/models/llava.py +20 -16
  74. sglang/srt/models/llavavid.py +13 -15
  75. sglang/srt/models/minicpm.py +13 -15
  76. sglang/srt/models/minicpm3.py +13 -15
  77. sglang/srt/models/mistral.py +13 -15
  78. sglang/srt/models/mixtral.py +15 -15
  79. sglang/srt/models/mixtral_quant.py +14 -14
  80. sglang/srt/models/olmo.py +21 -19
  81. sglang/srt/models/olmoe.py +23 -20
  82. sglang/srt/models/qwen.py +14 -14
  83. sglang/srt/models/qwen2.py +22 -19
  84. sglang/srt/models/qwen2_moe.py +17 -18
  85. sglang/srt/models/stablelm.py +18 -16
  86. sglang/srt/models/torch_native_llama.py +15 -17
  87. sglang/srt/models/xverse.py +13 -14
  88. sglang/srt/models/xverse_moe.py +15 -16
  89. sglang/srt/models/yivl.py +13 -15
  90. sglang/srt/openai_api/adapter.py +13 -15
  91. sglang/srt/openai_api/protocol.py +13 -15
  92. sglang/srt/sampling/sampling_batch_info.py +4 -1
  93. sglang/srt/sampling/sampling_params.py +13 -15
  94. sglang/srt/server.py +60 -34
  95. sglang/srt/server_args.py +22 -22
  96. sglang/srt/utils.py +208 -19
  97. sglang/test/few_shot_gsm8k.py +8 -4
  98. sglang/test/runners.py +13 -14
  99. sglang/test/test_utils.py +2 -2
  100. sglang/version.py +1 -1
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/LICENSE +1 -1
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/METADATA +25 -15
  103. sglang-0.3.6.post2.dist-info/RECORD +164 -0
  104. sglang/srt/layers/fused_moe/__init__.py +0 -1
  105. sglang-0.3.6.dist-info/RECORD +0 -161
  106. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  107. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/WHEEL +0 -0
  108. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/schedule_batch.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about requests and batches.
 
@@ -33,6 +31,7 @@ import dataclasses
 import logging
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -169,6 +168,30 @@ class ImageInputs:
 
         return ret
 
+    def merge(self, other, vocab_size):
+        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+        self.image_hashes += other.image_hashes
+
+        self.pad_values = [
+            (self.image_hashes) % vocab_size,
+            (self.image_hashes >> 16) % vocab_size,
+            (self.image_hashes >> 32) % vocab_size,
+            (self.image_hashes >> 64) % vocab_size,
+        ]
+
+        optional_args = [
+            "image_sizes",
+            "image_offsets",
+            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if getattr(self, arg, None) is not None:
+                setattr(self, arg, getattr(self, arg) + getattr(other, arg))
+
 
 class Req:
     """The input and output status of a request."""
@@ -179,13 +202,19 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
-        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+        self.origin_input_ids_unpadded = (
+            origin_input_ids_unpadded
+            if origin_input_ids_unpadded
+            else origin_input_ids  # Before image padding
+        )
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
@@ -193,6 +222,7 @@
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -260,6 +290,12 @@
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
+    def extend_image_inputs(self, image_inputs, vocab_size):
+        if self.image_inputs is None:
+            self.image_inputs = image_inputs
+        else:
+            self.image_inputs.merge(image_inputs, vocab_size)
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -439,14 +475,18 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None
 
-    # For utility
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -469,6 +509,7 @@
     extend_lens: List[int] = None
     extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -489,10 +530,11 @@
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -500,6 +542,7 @@
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -613,7 +656,7 @@
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -627,6 +670,9 @@
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
+        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -645,6 +691,11 @@
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
 
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -667,6 +718,12 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
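
prepare_for_extend now flattens every request's input_embeds into one list and packs it into a single batch tensor, mirroring how input_ids are packed. A sketch of the shape arithmetic under assumed sizes (two requests, hidden size 4):

import torch

req_a = [[0.0] * 4 for _ in range(3)]  # 3 prompt positions, hidden size 4
req_b = [[1.0] * 4 for _ in range(2)]  # 2 prompt positions

input_embeds = []
input_embeds.extend(req_a)  # extend, not append, so rows stay un-nested
input_embeds.extend(req_b)

batch = torch.tensor(input_embeds)
assert batch.shape == (5, 4)  # (total extend tokens, hidden size)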
@@ -707,7 +764,7 @@
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=enable_overlap_schedule,
+            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -724,16 +781,20 @@
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens += running_bs
+
+        # For overlap scheduler, the output_ids has one step delay
+        delta = 0 if self.enable_overlap else -1
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
             [
-                len(r.origin_input_ids) + len(r.output_ids) - 1
+                len(r.origin_input_ids) + len(r.output_ids) + delta
                 for r in running_batch.reqs
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
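
The delta change encodes the overlap scheduler's one-step delay: without overlap, the token being extended is already in output_ids, so the cached prefix must exclude it; with overlap, output_ids lags one step behind, so nothing is subtracted. Worked through with assumed lengths:

origin_input_ids = [10, 11, 12]  # 3 prompt tokens
output_ids = [7, 8]              # tokens decoded so far

# Normal scheduler: the last output token is the one being extended now,
# so it is not yet part of the cached prefix.
enable_overlap = False
delta = 0 if enable_overlap else -1
assert len(origin_input_ids) + len(output_ids) + delta == 4

# Overlap scheduler: output_ids is one step behind, so no correction.
enable_overlap = True
delta = 0 if enable_overlap else -1
assert len(origin_input_ids) + len(output_ids) + delta == 5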
@@ -897,7 +958,7 @@
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
-    def prepare_for_decode(self, enable_overlap: bool = False):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -914,7 +975,7 @@
         else:
             locs = self.seq_lens
 
-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
@@ -1045,6 +1106,7 @@
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1115,6 +1177,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(

sglang/srt/managers/schedule_policy.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Request scheduler policy"""
 
 import os

sglang/srt/managers/scheduler.py
@@ -1,21 +1,18 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -30,7 +27,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -74,8 +71,10 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    get_bool_env_var,
     get_zmq_socket,
     kill_parent_process,
+    set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -84,7 +83,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
 class Scheduler:
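
get_bool_env_var is one of the helpers this release adds to sglang/srt/utils.py (see the +208 in the file list). Its implementation is not shown in this hunk; a plausible equivalent of the ad-hoc check it replaces here, with the exact semantics an assumption:

import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Assumed semantics: case-insensitive "true" check, false when unset.
    return os.getenv(name, default).lower() == "true"

os.environ["SGLANG_TEST_RETRACT"] = "TRUE"
assert get_bool_env_var_sketch("SGLANG_TEST_RETRACT") is True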
@@ -304,6 +303,9 @@
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
+        # Tells whether the current running batch is full so that we can skip
+        # the check of whether to prefill new requests.
+        # This is an optimization to reduce the overhead of the prefill check.
         self.batch_is_full = False
 
         # Init watchdog thread
@@ -466,6 +468,7 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -524,14 +527,23 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
+
             if recv_req.session_id is not None:
                 req.finished_reason = FINISH_ABORT(
                     f"Invalid request: session id {recv_req.session_id} does not exist"
@@ -539,23 +551,22 @@
                 self.waiting_queue.append(req)
                 return
         else:
-            # Handle sessions
+            # Create a new request from a previsou session
             session = self.sessions[recv_req.session_id]
-            req, new_session_id = session.create_req(recv_req, self.tokenizer)
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return
 
         # Image inputs
         if recv_req.image_inputs is not None:
-            req.image_inputs = ImageInputs.from_dict(
+            image_inputs = ImageInputs.from_dict(
                 recv_req.image_inputs, self.model_config.vocab_size
             )
             req.origin_input_ids = self.pad_input_ids_func(
-                req.origin_input_ids_unpadded, req.image_inputs
+                req.origin_input_ids, image_inputs
             )
+            req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
 
         if len(req.origin_input_ids) > self.max_req_input_len:
             req.finished_reason = FINISH_ABORT(
@@ -723,40 +734,30 @@
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        # Prefill first
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-        before_bs = self.running_batch.batch_size()
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-        if before_bs != self.running_batch.batch_size():
-            self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -852,14 +853,20 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(self.enable_overlap)
+        new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
-        if self.is_mixed_chunk and self.running_batch is not None:
+        if (
+            self.is_mixed_chunk
+            and self.running_batch is not None
+            and not (new_batch.return_logprob or self.running_batch.return_logprob)
+        ):
+            # TODO (lianmin): support return_logprob + mixed chunked prefill
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(self.enable_overlap)
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -868,15 +875,16 @@
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-        batch = self.running_batch
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.running_batch = None
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -902,11 +910,15 @@
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.running_batch = None
-                return
+                self.batch_is_full = False
+                return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(self.enable_overlap)
+        batch.prepare_for_decode()
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -981,8 +993,13 @@
             if req.is_retracted:
                 continue
 
+            if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                # Free the one delayed token for the mixed decode batch
+                j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                continue
+
             if req.is_being_chunked <= 0:
-                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_id)
                 req.check_finished()
@@ -992,14 +1009,15 @@
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.grammar is not None:
-                    req.grammar.accept_token(next_token_id)
-
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
             else:
+                # Inflight reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
@@ -1017,18 +1035,18 @@
                 continue
 
             req.embedding = embeddings[i]
-            if req.is_being_chunked > 0:
-                req.is_being_chunked -= 1
-            else:
-                # Inflight reqs' prefill is not finished
-                # dummy output token for embedding models
+            if req.is_being_chunked <= 0:
+                # Dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
 
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
             else:
-                self.tree_cache.cache_unfinished_req(req)
+                # Inflight reqs' prefill is not finished
+                req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
 
@@ -1056,6 +1074,7 @@
                 continue
 
             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -1063,9 +1082,6 @@
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.grammar is not None:
-                req.grammar.accept_token(next_token_id)
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -1076,6 +1092,9 @@
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
+
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             torch.cuda.current_stream().synchronize()
@@ -1179,7 +1198,6 @@
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []
 
@@ -1207,7 +1225,6 @@
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
1258
1275
  output_meta_info,
1259
1276
  output_finished_reason,
1260
1277
  output_no_stop_trim,
1261
- output_session_ids,
1262
1278
  )
1263
1279
  )
1264
1280
  else: # embedding or reward model
@@ -1389,9 +1405,13 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None:
-        dp_rank = int(os.getenv("DP_RANK", -1))
+    if dp_rank is None and "DP_RANK" in os.environ:
+        dp_rank = int(os.environ["DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
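
The DP_RANK fix changes what happens when the variable is unset: the old code coerced it to -1, so dp_rank was never None afterwards and the TP-only logger prefix below could not be selected. A side-by-side sketch:

import os

os.environ.pop("DP_RANK", None)  # simulate an unset variable

dp_rank_old = int(os.getenv("DP_RANK", -1))  # old: always an int, here -1
dp_rank_new = int(os.environ["DP_RANK"]) if "DP_RANK" in os.environ else None

assert dp_rank_old == -1    # old behavior falls through to the DP logger branch
assert dp_rank_new is None  # new behavior keeps the TP-only prefix reachable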
@@ -1402,7 +1422,9 @@ def run_scheduler_process(
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
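
The startup handshake now sends a dict carrying max_total_num_tokens back to the launching process instead of the bare string "ready". A hypothetical reader on the launcher side (names assumed for illustration, not taken from the package):

import multiprocessing as mp

reader, writer = mp.Pipe(duplex=False)

# What the scheduler process now sends once initialization succeeds:
writer.send({"status": "ready", "max_total_num_tokens": 123456})

data = reader.recv()
assert data["status"] == "ready"
max_total_num_tokens = data["max_total_num_tokens"]  # available for server setup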