sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
@@ -3,8 +3,9 @@ Multi-modality utils
3
3
  """
4
4
 
5
5
  import hashlib
6
+ import pickle
6
7
  from abc import abstractmethod
7
- from typing import Callable, Dict, List, Optional, Tuple
8
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
8
9
 
9
10
  import numpy as np
10
11
  import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
27
28
  # propagation that can cause some log messages (like 'server is fired up') to not appear
28
29
  # in the console when multimodal support is enabled.
29
30
 
31
+ # TODO(mick): nccl
32
+ # cuda_ipc: for intranode tensor sharing
33
+ TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
34
+
35
+
36
+ class TransportProxyTensor(torch.Tensor):
37
+ """
38
+ A convenient torch.Tensor subclass that carries extra metadata and supports
39
+ efficient inter-process communications
40
+ """
41
+
42
+ @staticmethod
43
+ def __new__(
44
+ cls,
45
+ data: torch.Tensor,
46
+ name: Optional[str] = None,
47
+ fields: Optional[Dict[str, Any]] = None,
48
+ transport_mode: TensorTransportMode = "default",
49
+ *args,
50
+ **kwargs,
51
+ ):
52
+
53
+ if not isinstance(data, torch.Tensor):
54
+ raise TypeError(
55
+ f"Input 'data' must be a torch.Tensor, but got {type(data)}"
56
+ )
57
+
58
+ instance = data.as_subclass(cls)
59
+
60
+ instance._metadata = {
61
+ "name": name,
62
+ "fields": fields if fields is not None else {},
63
+ "transport_mode": transport_mode,
64
+ }
65
+
66
+ return instance
67
+
68
+ def __getstate__(self):
69
+ """
70
+ Called during pickling. Implements the serialization logic.
71
+ """
72
+ # acquire all serialize metadata from _metadata
73
+ state = {
74
+ "metadata": self._metadata,
75
+ "tensor_data": None,
76
+ "ipc_extra": None,
77
+ }
78
+
79
+ transport_mode = self._metadata.get("transport_mode", "default")
80
+
81
+ if transport_mode == "cuda_ipc" and self.is_cuda:
82
+ try:
83
+ storage = self.untyped_storage()
84
+ handle = storage._share_cuda_()
85
+
86
+ state["ipc_extra"] = {
87
+ "handle": handle,
88
+ "shape": self.shape,
89
+ "dtype": self.dtype,
90
+ "stride": self.stride(),
91
+ "device_index": self.device.index,
92
+ }
93
+ state["tensor_data"] = None
94
+ except Exception as e:
95
+ # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
96
+ state["metadata"]["transport_mode"] = "default"
97
+ state["tensor_data"] = self.as_subclass(torch.Tensor)
98
+ else:
99
+ state["metadata"]["transport_mode"] = "default"
100
+ state["tensor_data"] = self.as_subclass(torch.Tensor)
101
+
102
+ return state
103
+
104
+ def __setstate__(self, state: Dict[str, Any]):
105
+ """
106
+ Called during unpickling. Implements the deserialization logic.
107
+ """
108
+ self._metadata = state["metadata"]
109
+
110
+ transport_mode = self._metadata.get("transport_mode", "default")
111
+
112
+ if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
113
+ ipc_extra = state["ipc_extra"]
114
+ handle, shape, dtype, stride, source_device_index = (
115
+ ipc_extra["handle"],
116
+ ipc_extra["shape"],
117
+ ipc_extra["dtype"],
118
+ ipc_extra["stride"],
119
+ ipc_extra["device_index"],
120
+ )
121
+
122
+ try:
123
+ target_device = torch.device(f"cuda:{source_device_index}")
124
+ with torch.cuda.device(target_device):
125
+ storage = torch.UntypedStorage._new_shared_cuda(*handle)
126
+ reconstructed_tensor = torch.empty(
127
+ 0, dtype=dtype, device=target_device
128
+ ).set_(storage, storage_offset=0, size=shape, stride=stride)
129
+ self.set_(reconstructed_tensor)
130
+ except Exception as e:
131
+ print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
132
+ raise e
133
+
134
+ elif state["tensor_data"] is not None:
135
+ self.set_(state["tensor_data"])
136
+ else:
137
+ raise pickle.UnpicklingError(
138
+ "Invalid state for TransportProxyTensor: no tensor data found."
139
+ )
140
+
141
+ @property
142
+ def name(self) -> Optional[str]:
143
+ return self._metadata.get("name")
144
+
145
+ @property
146
+ def fields(self) -> Dict[str, Any]:
147
+ return self._metadata.get("fields", {})
148
+
149
+ @property
150
+ def transport_mode(self) -> TensorTransportMode:
151
+ return self._metadata.get("transport_mode", "default")
152
+
30
153
 
31
154
  class MultiModalityDataPaddingPattern:
32
155
  """
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
85
208
  "No data_token_pairs provided, RadixAttention might be influenced."
86
209
  )
87
210
  return input_ids
88
- start_token_ids = [s for s, _e in data_token_pairs]
89
- end_tokens_ids = [e for _s, e in data_token_pairs]
211
+ start_token_ids = {s for s, _e in data_token_pairs}
212
+ end_tokens_ids = {e for _s, e in data_token_pairs}
90
213
 
91
214
  padded_ids = []
92
215
  last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
135
258
  if not input_ids or not mm_inputs.mm_items:
136
259
  return input_ids
137
260
 
138
- input_ids_tensor = torch.tensor(input_ids)
261
+ input_ids_tensor = torch.as_tensor(input_ids)
139
262
 
140
263
  # Create mapping of token_ids to pad_values for each modality
141
264
  token_to_pad_mapping = {}
@@ -211,7 +334,7 @@ def get_embedding_chunk(
211
334
  end_index += extend_end_index - start + 1
212
335
  elif extend_end_index > end:
213
336
  end_index += end - start + 1
214
- # some models embedding is 3-dim, reshape it to 2-dim
337
+ # some models' embedding is 3-dim, reshape it to 2-dim
215
338
  embedding = embedding.reshape(-1, embedding.shape[-1])
216
339
  embedding_chunk = embedding[start_index:end_index]
217
340
  return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
428
551
  modality_id = modality.name.lower()
429
552
  embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
430
553
  if len(items) != 0 and embedder is not None:
431
- placeholder_tensor = torch.tensor(
554
+ placeholder_tensor = torch.as_tensor(
432
555
  [item.pad_value for item in items],
433
556
  device=input_ids.device,
434
557
  )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
473
596
  for embedding, mask in zip(embeddings, masks):
474
597
  if embedding is None or mask is None:
475
598
  continue
476
- mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
477
- inputs_embeds = inputs_embeds.masked_scatter(
478
- mask,
479
- embedding.to(inputs_embeds.device, inputs_embeds.dtype),
480
- )
599
+ # in-place update
600
+ indices = torch.where(mask.squeeze(dim=-1))[0]
601
+ inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
481
602
  return inputs_embeds
482
603
 
483
604
 
@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
561
682
  [bounds_count, 2]
562
683
  """
563
684
  # All the multimodal data in the batch should share the same special bound token ids.
564
- start_tokens = [s for s, _e in token_pairs]
565
- end_tokens = [e for _s, e in token_pairs]
685
+ start_tokens = {s for s, _e in token_pairs}
686
+ end_tokens = {e for _s, e in token_pairs}
566
687
 
567
688
  assert all(isinstance(t, int) for t in start_tokens)
568
689
  assert all(isinstance(t, int) for t in end_tokens)
569
690
 
570
691
  start_cond = torch.isin(
571
- input_ids, torch.tensor(start_tokens, device=input_ids.device)
692
+ input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
693
+ )
694
+ end_cond = torch.isin(
695
+ input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
572
696
  )
573
- end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
574
697
 
575
698
  (data_start_tokens,) = torch.where(start_cond)
576
699
  (data_end_tokens,) = torch.where(end_cond)
577
700
 
701
+ data_start_tokens_cpu = data_start_tokens.cpu().tolist()
702
+ data_end_tokens_cpu = data_end_tokens.cpu().tolist()
703
+
578
704
  # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
579
- if len(data_start_tokens) != len(data_end_tokens):
705
+ if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
580
706
  if (
581
- len(data_start_tokens) + 1 == len(data_end_tokens)
582
- and input_ids[0] in pad_values
583
- and data_end_tokens[0] < data_start_tokens[0]
707
+ len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
708
+ and input_ids[0].item() in pad_values
709
+ and data_end_tokens_cpu
710
+ and data_start_tokens_cpu
711
+ and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
584
712
  ):
585
- data_start_tokens = torch.cat(
586
- [
587
- torch.tensor([0], device=data_start_tokens.device),
588
- data_start_tokens,
589
- ]
590
- )
591
- valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
713
+ data_start_tokens_cpu.insert(0, 0)
714
+ valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
592
715
 
593
716
  if valid_mm_data_nums == 0:
594
717
  return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
596
719
  # Filter out pairs where start_token >= end_token
597
720
  valid_pairs = []
598
721
  for i in range(valid_mm_data_nums):
599
- start_token = data_start_tokens[i]
600
- end_token = data_end_tokens[i]
722
+ start_token = data_start_tokens_cpu[i]
723
+ end_token = data_end_tokens_cpu[i]
601
724
  if start_token < end_token:
602
725
  valid_pairs.append((start_token + 1, end_token - 1))
603
726
 
@@ -605,7 +728,7 @@ def get_multimodal_data_bounds(
605
728
  return torch.zeros((0, 2), device=input_ids.device)
606
729
 
607
730
  # Convert valid pairs to tensor
608
- valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
731
+ valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
609
732
  return valid_pairs_tensor
610
733
 
611
734
 
@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
626
749
  ]
627
750
  tensor = torch.concat(tensor_list)
628
751
  if tensor.is_cuda:
629
- return gpu_tensor_hash(tensor)
752
+ return gpu_tensor_hash(tensor.cuda())
630
753
  tensor = tensor.detach().contiguous()
631
754
 
632
755
  if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@ def tensor_hash(tensor_list) -> int:
634
757
  tensor = tensor.float()
635
758
 
636
759
  assert isinstance(tensor, torch.Tensor)
637
- if tensor.is_cuda:
638
- # TODO: improve this
639
- tensor_cpu = tensor.cpu()
640
- else:
641
- tensor_cpu = tensor
760
+ tensor_cpu = tensor.cpu()
642
761
 
643
762
  mv = memoryview(tensor_cpu.numpy())
644
763
  return data_hash(mv.tobytes())
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
12
12
  PROCESSOR_MAPPING = {}
13
13
 
14
14
 
15
- class DummyMultimodalProcessor(BaseMultimodalProcessor):
16
- def __init__(self):
17
- pass
18
-
19
- async def process_mm_data_async(self, *args, **kwargs):
20
- return None
21
-
22
-
23
- def get_dummy_processor():
24
- return DummyMultimodalProcessor()
25
-
26
-
27
15
  def import_processors():
28
16
  package_name = "sglang.srt.multimodal.processors"
29
17
  package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():
49
37
 
50
38
 
51
39
  def get_mm_processor(
52
- hf_config, server_args: ServerArgs, processor
40
+ hf_config, server_args: ServerArgs, processor, transport_mode
53
41
  ) -> BaseMultimodalProcessor:
54
42
  for model_cls, processor_cls in PROCESSOR_MAPPING.items():
55
43
  if model_cls.__name__ in hf_config.architectures:
56
- return processor_cls(hf_config, server_args, processor)
44
+ return processor_cls(hf_config, server_args, processor, transport_mode)
45
+
57
46
  raise ValueError(
58
47
  f"No processor registered for architecture: {hf_config.architectures}.\n"
59
48
  f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
88
88
  "enable_deepep_moe",
89
89
  "deepep_mode",
90
90
  "enable_ep_moe",
91
- "enable_flashinfer_moe",
91
+ "enable_flashinfer_cutlass_moe",
92
+ "enable_flashinfer_trtllm_moe",
92
93
  "enable_flashinfer_allreduce_fusion",
93
94
  "moe_dense_tp_size",
94
95
  "ep_dispatch_algorithm",
@@ -209,10 +210,11 @@ class MultimodalDataItem:
209
210
  hash: int = None
210
211
  pad_value: int = None
211
212
  offsets: Optional[list] = None
213
+
212
214
  # the raw features returned by processor, e.g. pixel_values or audio_features
213
215
  feature: Union[torch.Tensor, np.ndarray] = None
214
-
215
- # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
216
+ # the precomputed embeddings, passed as final encoder embeddings
217
+ # One and only one of the feature and precomputed_embeddings will be empty
216
218
  precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
217
219
 
218
220
  # Model-specific data stored in a dictionary
@@ -1688,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1688
1690
  extend_prefix_lens = self.prefix_lens
1689
1691
  extend_logprob_start_lens = self.extend_logprob_start_lens
1690
1692
 
1693
+ if self.forward_mode.is_decode_or_idle():
1694
+ attention_backend_str = global_server_args_dict["decode_attention_backend"]
1695
+ else:
1696
+ attention_backend_str = global_server_args_dict["prefill_attention_backend"]
1691
1697
  # Create seq_lens_cpu when needed
1692
1698
  if (
1693
- global_server_args_dict["attention_backend"] == "fa3"
1699
+ attention_backend_str == "fa3"
1694
1700
  or (
1695
1701
  global_server_args_dict["use_mla_backend"]
1696
- and global_server_args_dict["attention_backend"] == "flashinfer"
1702
+ and attention_backend_str == "flashinfer"
1697
1703
  )
1698
- or global_server_args_dict["attention_backend"] == "flashmla"
1699
- or global_server_args_dict["attention_backend"] == "cutlass_mla"
1700
- or global_server_args_dict["attention_backend"] == "ascend"
1704
+ or attention_backend_str == "flashmla"
1705
+ or attention_backend_str == "cutlass_mla"
1706
+ or attention_backend_str == "ascend"
1701
1707
  or global_server_args_dict["enable_two_batch_overlap"]
1702
1708
  ):
1703
1709
  seq_lens_cpu = (
@@ -24,6 +24,7 @@ import time
24
24
  from collections import defaultdict, deque
25
25
  from concurrent import futures
26
26
  from dataclasses import dataclass
27
+ from http import HTTPStatus
27
28
  from pathlib import Path
28
29
  from types import SimpleNamespace
29
30
  from typing import Dict, List, Optional, Tuple, Union
@@ -122,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
122
123
  PrefillAdder,
123
124
  SchedulePolicy,
124
125
  )
126
+ from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
125
127
  from sglang.srt.managers.scheduler_output_processor_mixin import (
126
128
  SchedulerOutputProcessorMixin,
127
129
  )
@@ -370,6 +372,7 @@ class Scheduler(
370
372
  self.max_total_num_tokens,
371
373
  self.max_prefill_tokens,
372
374
  self.max_running_requests,
375
+ self.max_queued_requests,
373
376
  self.max_req_len,
374
377
  self.max_req_input_len,
375
378
  self.random_seed,
@@ -458,7 +461,10 @@ class Scheduler(
458
461
  self.grammar_queue: List[Req] = []
459
462
  if not server_args.skip_tokenizer_init:
460
463
  self.grammar_backend = create_grammar_backend(
461
- server_args, self.tokenizer, self.model_config.vocab_size
464
+ server_args,
465
+ self.tokenizer,
466
+ self.model_config.vocab_size,
467
+ self.model_config.hf_eos_token_id,
462
468
  )
463
469
  else:
464
470
  self.grammar_backend = None
@@ -499,6 +505,12 @@ class Scheduler(
499
505
  )
500
506
  self.init_profier()
501
507
 
508
+ self.input_blocker = (
509
+ SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
510
+ if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
511
+ else None
512
+ )
513
+
502
514
  # Init metrics stats
503
515
  self.init_metrics(tp_rank, pp_rank, dp_rank)
504
516
  self.init_kv_events(server_args.kv_events_config)
@@ -1030,6 +1042,9 @@ class Scheduler(
1030
1042
  else:
1031
1043
  recv_reqs = None
1032
1044
 
1045
+ if self.input_blocker is not None:
1046
+ recv_reqs = self.input_blocker.handle(recv_reqs)
1047
+
1033
1048
  if self.server_args.enable_dp_attention:
1034
1049
  if self.attn_tp_rank == 0:
1035
1050
  work_reqs = [
@@ -1083,6 +1098,19 @@ class Scheduler(
1083
1098
  self.return_health_check_ct += 1
1084
1099
  continue
1085
1100
 
1101
+ # If it is a work request, accept or reject the request based on the request queue size.
1102
+ if is_work_request(recv_req):
1103
+ if len(self.waiting_queue) + 1 > self.max_queued_requests:
1104
+ abort_req = AbortReq(
1105
+ recv_req.rid,
1106
+ finished_reason={
1107
+ "type": "abort",
1108
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1109
+ "message": "The request queue is full.",
1110
+ },
1111
+ )
1112
+ self.send_to_tokenizer.send_pyobj(abort_req)
1113
+ continue
1086
1114
  output = self._request_dispatcher(recv_req)
1087
1115
  if output is not None:
1088
1116
  if isinstance(output, RpcReqOutput):
@@ -2437,6 +2465,37 @@ class Scheduler(
2437
2465
  req.grammar.cancel()
2438
2466
  req.set_finish_with_abort("Aborted by AbortReq.")
2439
2467
 
2468
+ # Delete requests not in the waiting queue when PD disaggregation is enabled
2469
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
2470
+ # Abort requests that have not yet been bootstrapped
2471
+ for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
2472
+ logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2473
+ if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2474
+ if hasattr(req.disagg_kv_sender, "abort"):
2475
+ req.disagg_kv_sender.abort()
2476
+
2477
+ # Abort in-flight requests
2478
+ for i, req in enumerate(self.disagg_prefill_inflight_queue):
2479
+ logger.debug(f"Abort inflight queue request. {req.rid=}")
2480
+ if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2481
+ if hasattr(req.disagg_kv_sender, "abort"):
2482
+ req.disagg_kv_sender.abort()
2483
+
2484
+ elif self.disaggregation_mode == DisaggregationMode.DECODE:
2485
+ # Abort requests that have not yet finished preallocation
2486
+ for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
2487
+ logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2488
+ if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2489
+ if hasattr(decode_req.kv_receiver, "abort"):
2490
+ decode_req.kv_receiver.abort()
2491
+
2492
+ # Abort requests waiting for kvcache to release tree cache
2493
+ for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
2494
+ logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2495
+ if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2496
+ if hasattr(decode_req.kv_receiver, "abort"):
2497
+ decode_req.kv_receiver.abort()
2498
+
2440
2499
  # Delete requests in the running batch
2441
2500
  if self.cur_batch is self.running_batch or self.cur_batch is None:
2442
2501
  reqs = self.running_batch.reqs
@@ -2868,6 +2927,10 @@ def is_health_check_generate_req(recv_req):
2868
2927
  return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2869
2928
 
2870
2929
 
2930
+ def is_work_request(recv_req):
2931
+ return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
2932
+
2933
+
2871
2934
  def _export_static_state(model):
2872
2935
  return dict(
2873
2936
  buffers=[
@@ -0,0 +1,106 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import logging
15
+ from contextlib import contextmanager
16
+ from enum import Enum, auto
17
+ from typing import Any, List, Optional
18
+
19
+ from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
20
+ from sglang.srt.poll_based_barrier import PollBasedBarrier
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class SchedulerInputBlocker:
26
+ def __init__(self, noop: bool):
27
+ self._state = _State.UNBLOCKED
28
+ self._pending_reqs = []
29
+ self._noop = noop
30
+ self._global_unblock_barrier = PollBasedBarrier(noop=noop)
31
+
32
+ def handle(self, recv_reqs: Optional[List[Any]]):
33
+ assert (recv_reqs is None) == self._noop
34
+
35
+ if not self._noop:
36
+ output_reqs = []
37
+ for recv_req in recv_reqs:
38
+ output_reqs += self._handle_recv_req(recv_req)
39
+
40
+ global_arrived_unblock_barrier = (
41
+ self._global_unblock_barrier.poll_global_arrived()
42
+ )
43
+ if (
44
+ self._state == _State.GLOBAL_UNBLOCK_BARRIER
45
+ and global_arrived_unblock_barrier
46
+ ):
47
+ output_reqs += self._handle_arrive_unblock_barrier()
48
+
49
+ if not self._noop:
50
+ return output_reqs
51
+
52
+ def _handle_recv_req(self, recv_req):
53
+ if isinstance(recv_req, BlockReqInput):
54
+ if recv_req.type == BlockReqType.BLOCK:
55
+ self._execute_block_req()
56
+ return []
57
+ elif recv_req.type == BlockReqType.UNBLOCK:
58
+ self._execute_unblock_req()
59
+ return []
60
+ else:
61
+ raise NotImplementedError(f"{recv_req=}")
62
+ else:
63
+ if self._state == _State.UNBLOCKED:
64
+ return [recv_req]
65
+ else:
66
+ self._pending_reqs.append(recv_req)
67
+ return []
68
+
69
+ def _execute_block_req(self):
70
+ logger.info("Handle block req")
71
+ self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)
72
+
73
+ def _execute_unblock_req(self):
74
+ logger.info("Handle unblock req")
75
+ self._change_state(
76
+ original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
77
+ )
78
+ self._global_unblock_barrier.local_arrive()
79
+
80
+ def _handle_arrive_unblock_barrier(self):
81
+ logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
82
+ self._change_state(
83
+ original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
84
+ )
85
+ output_reqs = [*self._pending_reqs]
86
+ self._pending_reqs.clear()
87
+ return output_reqs
88
+
89
+ def _change_state(self, original: "_State", target: "_State"):
90
+ assert self._state == original, f"{self._state=} {original=} {target=}"
91
+ self._state = target
92
+
93
+
94
+ class _State(Enum):
95
+ UNBLOCKED = auto()
96
+ BLOCKED = auto()
97
+ GLOBAL_UNBLOCK_BARRIER = auto()
98
+
99
+
100
+ @contextmanager
101
+ def input_blocker_guard_region(send_to_scheduler):
102
+ send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
103
+ try:
104
+ yield
105
+ finally:
106
+ send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))