sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/pixtral.py (new file)
@@ -0,0 +1,127 @@
+ import asyncio
+ import math
+ from typing import List, Union
+
+ from transformers.models.pixtral.image_processing_pixtral import (
+     _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+ )
+
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
+ )
+ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+ from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+ class PixtralProcessor(BaseMultimodalProcessor):
+     models = [PixtralVisionModel]
+
+     PAD_TOKEN = "<pad>"
+     IMG_BREAK_TOKEN_ID = 12
+     IMG_END_TOKEN_ID = 13
+
+     def get_patch_grid_size(
+         self,
+         *,
+         image_width: int,
+         image_height: int,
+     ) -> tuple[int, int]:
+         max_width = max_height = self.image_size
+         patch_width = patch_height = self.patch_size
+
+         ratio = max(image_width / max_width, image_height / max_height)
+
+         if ratio > 1:
+             image_width = int(math.floor(image_width / ratio))
+             image_height = int(math.floor(image_height / ratio))
+
+         nrows, ncols = _get_pixtral_hf_num_image_tokens(
+             (image_height, image_width),
+             (patch_height, patch_width),
+         )
+
+         return ncols, nrows
+
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.image_token_id = getattr(
+             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+         )
+         # Instantiate the patcher logic helper using the class defined above
+
+         self.vision_config = hf_config.vision_config
+         self.image_size = self.vision_config.image_size
+         self.patch_size = self.vision_config.patch_size
+         self.multimodal_tokens = MultimodalSpecialTokens(
+             image_token=_processor.image_token
+         )
+         _processor.tokenizer.add_special_tokens(
+             {
+                 "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+             }
+         )
+
+     async def _resize(self, image):
+         num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+             image_width=image.size[0],
+             image_height=image.size[1],
+         )
+         new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+         return image.resize(new_size)
+
+     async def process_mm_data_async(
+         self,
+         image_data: List[Union[str, bytes]],
+         input_text,
+         request_obj,
+         *args,
+         **kwargs,
+     ):
+         if not image_data:
+             return None
+
+         if isinstance(image_data, str):
+             image_data = [image_data]
+
+         mm_data = self.load_mm_data(
+             prompt=input_text,
+             multimodal_tokens=self.multimodal_tokens,
+             max_req_input_len=kwargs.get("max_req_input_len", 4096),
+             image_data=image_data,
+             return_text=True,
+         )
+
+         if mm_data.images:
+             resize_tasks = [self._resize(image) for image in mm_data.images]
+             mm_data.images = await asyncio.gather(*resize_tasks)
+
+         processor_output = self.process_mm_data(
+             input_text=mm_data.input_text,
+             images=mm_data.images,
+         )
+
+         if "pixel_values" in processor_output:
+             input_ids = processor_output["input_ids"].view(-1)
+             image_offsets = self.get_mm_items_offset(
+                 input_ids=input_ids,
+                 mm_token_id=self.image_token_id,
+             )
+             mm_items = [
+                 MultimodalDataItem(
+                     pixel_values=processor_output["pixel_values"],
+                     image_sizes=processor_output["image_sizes"],
+                     modality=Modality.IMAGE,
+                     image_offsets=image_offsets,
+                 )
+             ]
+
+             input_ids = input_ids.tolist()
+             processor_output.update(
+                 input_ids=input_ids,
+                 mm_items=mm_items,
+                 # there's no im_start_id for pixtral, only im_token and im_end_token
+                 im_end_id=self.IMG_END_TOKEN_ID,
+                 im_token_id=self.image_token_id,
+             )
+         return processor_output
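
Note: get_patch_grid_size above first downscales the image so neither side exceeds the model's maximum resolution, then ceil-divides each side into patches; _resize then snaps the image to exactly that grid. A minimal standalone sketch of the arithmetic (image_size=1024 and patch_size=16 are typical Pixtral values assumed here; the real ones come from hf_config.vision_config):

import math

def patch_grid_size(image_width, image_height, image_size=1024, patch_size=16):
    # Downscale so neither side exceeds the maximum image size, preserving aspect ratio.
    ratio = max(image_width / image_size, image_height / image_size)
    if ratio > 1:
        image_width = math.floor(image_width / ratio)
        image_height = math.floor(image_height / ratio)
    # Ceil-divide each side into patches (what transformers' _num_image_tokens computes).
    ncols = (image_width - 1) // patch_size + 1
    nrows = (image_height - 1) // patch_size + 1
    return ncols, nrows

# A 2048x1536 image is downscaled to 1024x768 and maps to a 64x48 patch grid,
# so _resize would rescale it to (64 * 16, 48 * 16) = (1024, 768).
print(patch_grid_size(2048, 1536))  # (64, 48)
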
sglang/srt/managers/multimodal_processors/qwen_vl.py
@@ -1,6 +1,7 @@
  import asyncio
  import math
- from typing import List, Union
+ import re
+ from typing import Dict, List, Union

  import torch
  from PIL import Image
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):

      def __init__(self, hf_config, server_args, _processor):
          super().__init__(hf_config, server_args, _processor)
+         # The single, pre-expanded image token.
          self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+         # The regex that matches expanded image tokens.
+         self.IMAGE_TOKEN_REGEX = re.compile(
+             r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+         )
          self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
          self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
          self.image_token_id = hf_config.image_token_id
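
The new IMAGE_TOKEN_REGEX complements the single pre-expanded IMAGE_TOKEN: it also recognizes prompts in which the image placeholder has already been expanded to one <|image_pad|> per vision patch. A quick illustrative check:

import re

IMAGE_TOKEN_REGEX = re.compile(r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>")

expanded = "<|vision_start|>" + "<|image_pad|>" * 4 + "<|vision_end|>"
assert IMAGE_TOKEN_REGEX.fullmatch(expanded)  # one or more pads match
assert not IMAGE_TOKEN_REGEX.fullmatch("<|vision_start|><|vision_end|>")  # zero pads do not
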
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):

      async def process_mm_data_async(
          self,
-         image_data: List[Union[str, bytes]],
+         image_data: List[Union[str, bytes, Dict]],
          input_text,
          request_obj,
          max_req_input_len,
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
          if isinstance(image_data, str):
              image_data = [image_data]

-         image_token = self.IMAGE_TOKEN
          base_output = self.load_mm_data(
              prompt=input_text,
              image_data=image_data,
-             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+             multimodal_tokens=MultimodalSpecialTokens(
+                 image_token=self.IMAGE_TOKEN,
+                 image_token_regex=self.IMAGE_TOKEN_REGEX,
+             ),
              max_req_input_len=max_req_input_len,
          )

@@ -117,26 +125,60 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
          async def resize_image_async(image):
              return resize_image(image)

-         if base_output.images:
+         images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+         if base_output.images and not images_are_preprocessed:
              resize_tasks = [resize_image_async(image) for image in base_output.images]
              base_output.images = await asyncio.gather(*resize_tasks)

          ret = self.process_mm_data(
              input_text=base_output.input_text,
-             images=base_output.images,
+             images=None if images_are_preprocessed else base_output.images,
          )
-
+         input_ids = ret["input_ids"].flatten().tolist()
+         image_offsets = self.get_mm_items_offset(
+             input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
+         )
+         image_grid_thw = None
+         video_grid_thw = None  # TODO
          items = []

-         input_ids = ret["input_ids"].flatten().tolist()
-         if "pixel_values" in ret:
+         if base_output.images:
+             if images_are_preprocessed:
+                 image_grid_thw = torch.concat(
+                     [
+                         torch.as_tensor(item.image_grid_thws)
+                         for item in base_output.images
+                     ]
+                 )
+                 all_pixel_values = [
+                     item.pixel_values
+                     for item in base_output.images
+                     if item.pixel_values is not None
+                 ]
+                 all_precomputed_features = [
+                     item.precomputed_features
+                     for item in base_output.images
+                     if item.precomputed_features is not None
+                 ]
+                 pixel_values = (
+                     torch.concat(all_pixel_values) if all_pixel_values else None
+                 )
+                 precomputed_features = (
+                     torch.concat(all_precomputed_features)
+                     if all_precomputed_features
+                     else None
+                 )
+             else:
+                 image_grid_thw = ret["image_grid_thw"]
+                 pixel_values = ret["pixel_values"]
+                 precomputed_features = None
              items += [
                  MultimodalDataItem(
-                     pixel_values=ret["pixel_values"],
-                     image_grid_thws=torch.concat([ret["image_grid_thw"]]),
-                     # TODO
-                     video_grid_thws=None,
-                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                     pixel_values=pixel_values,
+                     image_grid_thws=image_grid_thw,
+                     video_grid_thws=video_grid_thw,
+                     precomputed_features=precomputed_features,
+                     image_offsets=image_offsets,
                      modality=Modality.IMAGE,
                  )
              ]
@@ -151,8 +193,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                  self.hf_config.vision_config, "tokens_per_second", None
              ),
              input_ids=torch.tensor(input_ids).unsqueeze(0),
-             image_grid_thw=ret.get("image_grid_thw", None),
-             video_grid_thw=ret.get("video_grid_thw", None),
+             image_grid_thw=image_grid_thw,
+             video_grid_thw=video_grid_thw,
              second_per_grid_ts=ret.get("second_per_grid_ts", None),
          )
          mrope_positions = mrope_positions.squeeze(1)
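
When images arrive already preprocessed (images_are_preprocessed above), the processor skips resizing and the HF processor call and merges the per-item tensors itself. A minimal sketch of that merge, using a hypothetical Item stand-in for one preprocessed entry:

import torch

class Item:  # hypothetical stand-in for one preprocessed image entry
    def __init__(self, image_grid_thws, pixel_values=None, precomputed_features=None):
        self.image_grid_thws = image_grid_thws
        self.pixel_values = pixel_values
        self.precomputed_features = precomputed_features

def merge_preprocessed(items):
    # Stack all (t, h, w) grids; concatenate whichever payload each item carries.
    image_grid_thw = torch.concat([torch.as_tensor(i.image_grid_thws) for i in items])
    pv = [i.pixel_values for i in items if i.pixel_values is not None]
    pf = [i.precomputed_features for i in items if i.precomputed_features is not None]
    return (
        torch.concat(pv) if pv else None,  # pixel_values
        torch.concat(pf) if pf else None,  # precomputed_features
        image_grid_thw,
    )
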
sglang/srt/managers/schedule_batch.py
@@ -1,8 +1,5 @@
  from __future__ import annotations

- import hashlib
- from enum import Enum, auto
-
  # Copyright 2023-2024 SGLang Team
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
    It will be transformed from CPU scheduler to GPU model runner.
  - ForwardBatch is managed by `model_runner.py::ModelRunner`.
    It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
  """

  import copy
  import dataclasses
+ import hashlib
  import logging
  import threading
+ from enum import Enum, auto
  from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

  import numpy as np
@@ -47,10 +48,14 @@ from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
  from sglang.srt.disaggregation.base import BaseKVSender
- from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
+ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
+     ScheduleBatchDisaggregationDecodeMixin,
+ )
+ from sglang.srt.layers.multimodal import gpu_tensor_hash
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.metrics.collector import TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -73,17 +78,21 @@ global_server_args_dict = {
      "disable_radix_cache": ServerArgs.disable_radix_cache,
      "enable_deepep_moe": ServerArgs.enable_deepep_moe,
      "enable_dp_attention": ServerArgs.enable_dp_attention,
+     "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
      "enable_ep_moe": ServerArgs.enable_ep_moe,
+     "deepep_config": ServerArgs.deepep_config,
      "enable_nan_detection": ServerArgs.enable_nan_detection,
      "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
      "max_micro_batch_size": ServerArgs.max_micro_batch_size,
      "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
+     "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
      "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
      "sampling_backend": ServerArgs.sampling_backend,
      "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
      "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
      "torchao_config": ServerArgs.torchao_config,
      "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+     "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
  }

  logger = logging.getLogger(__name__)
@@ -134,9 +143,9 @@ class FINISH_LENGTH(BaseFinishReason):


  class FINISH_ABORT(BaseFinishReason):
-     def __init__(self, message="Unknown error", status_code=None, err_type=None):
+     def __init__(self, message=None, status_code=None, err_type=None):
          super().__init__(is_error=True)
-         self.message = message
+         self.message = message or "Aborted"
          self.status_code = status_code
          self.err_type = err_type

@@ -174,10 +183,10 @@ class MultimodalDataItem:
      image_offsets: Optional[list] = None

      # the real data, pixel_values or audio_features
-     # data: Union[List[torch.Tensor], List[np.array]]
-     pixel_values: Union[torch.Tensor, np.array] = None
-     image_grid_thws: Union[torch.Tensor, np.array] = None
-     video_grid_thws: Union[torch.Tensor, np.array] = None
+     # data: Union[List[torch.Tensor], List[np.ndarray]]
+     pixel_values: Union[torch.Tensor, np.ndarray] = None
+     image_grid_thws: Union[torch.Tensor, np.ndarray] = None
+     video_grid_thws: Union[torch.Tensor, np.ndarray] = None

      image_emb_mask: Optional[torch.Tensor] = None
      image_spatial_crop: Optional[torch.Tensor] = None
@@ -186,8 +195,11 @@ class MultimodalDataItem:
      # [num_images, (n, w, h)]
      tgt_size: Tuple[int, int] = None

-     audio_features: Union[torch.Tensor, np.array] = None
+     audio_features: Union[torch.Tensor, np.ndarray] = None
      audio_feature_lens: Optional[List[torch.Tensor]] = None
+     audio_offsets: Optional[List[Tuple[int, int]]] = None
+
+     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

      @staticmethod
      def is_empty_list(l):
@@ -216,7 +228,8 @@ class MultimodalDataItem:
                  for x in tensor_list
              ]
              tensor = torch.concat(tensor_list)
-
+             if tensor.is_cuda:
+                 return gpu_tensor_hash(tensor)
              tensor = tensor.detach().contiguous()

              if tensor.dtype == torch.bfloat16:
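
The is_cuda branch above avoids a device-to-host copy: CUDA tensors are fingerprinted on the GPU by the new gpu_tensor_hash (from sglang.srt.layers.multimodal, added in this release), while CPU tensors keep the bytes-hashing path. A hedged sketch of the dispatch, with the module's data_hash approximated here by hashlib over the raw buffer:

import hashlib

import torch
from sglang.srt.layers.multimodal import gpu_tensor_hash

def tensor_fingerprint(t: torch.Tensor) -> int:
    if t.is_cuda:
        return gpu_tensor_hash(t)  # hash on-device, no host copy
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # numpy cannot represent bfloat16
    # Approximation of data_hash: hash the raw bytes of the buffer.
    return int.from_bytes(hashlib.sha256(t.numpy().tobytes()).digest()[:8], "big")
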
@@ -246,7 +259,9 @@ class MultimodalDataItem:
                  return tensor_hash([f])
              return data_hash(f)

-         if self.is_audio():
+         if self.precomputed_features is not None:
+             self.hash = hash_feature(self.precomputed_features)
+         elif self.is_audio():
              self.hash = hash_feature(self.audio_features)
          else:
              self.hash = hash_feature(self.pixel_values)
@@ -255,19 +270,24 @@ class MultimodalDataItem:
          self.pad_value = self.hash % (1 << 30)

      def is_audio(self):
-         return (
-             self.modality == Modality.AUDIO
-         ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+         return (self.modality == Modality.AUDIO) and (
+             self.precomputed_features is not None
+             or not MultimodalDataItem.is_empty_list(self.audio_features)
+         )

      def is_image(self):
          return (
              self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
-         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+         ) and (
+             self.precomputed_features is not None
+             or not MultimodalDataItem.is_empty_list(self.pixel_values)
+         )

      def is_video(self):
-         return (
-             self.modality == Modality.VIDEO
-         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+         return (self.modality == Modality.VIDEO) and (
+             self.precomputed_features is not None
+             or not MultimodalDataItem.is_empty_list(self.pixel_values)
+         )

      def is_valid(self) -> bool:
          return self.is_image() or self.is_video() or self.is_audio()
@@ -276,6 +296,16 @@ class MultimodalDataItem:
          ...
          # TODO

+     @staticmethod
+     def from_dict(obj: dict):
+         kwargs = dict(obj)
+         modality = kwargs.pop("modality")
+         if isinstance(modality, str):
+             modality = Modality[modality]
+         ret = MultimodalDataItem(modality=modality, **kwargs)
+         ret.validate()
+         return ret
+

  @dataclasses.dataclass
  class MultimodalInputs:
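
The new MultimodalDataItem.from_dict mirrors MultimodalInputs.from_dict below: it rebuilds an item from a plain dict (for example, one passed between processes), resolving a string modality back into the Modality enum. A usage sketch, assuming this package version is installed:

import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem.from_dict(
    {"modality": "IMAGE", "pixel_values": torch.zeros(1, 3, 28, 28)}
)
assert item.modality is Modality.IMAGE
assert item.is_image()  # non-empty pixel_values (or precomputed_features) makes it valid
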
@@ -301,8 +331,9 @@ class MultimodalInputs:
      video_token_id: Optional[int] = None

      # audio
-     audio_start_id: Optional[torch.Tensor] = None
-     audio_end_id: Optional[torch.Tensor] = None
+     audio_token_id: Optional[int] = None
+     audio_start_id: Optional[int] = None
+     audio_end_id: Optional[int] = None

      @staticmethod
      def from_dict(obj: dict):
@@ -326,6 +357,7 @@ class MultimodalInputs:
              "slice_end_id",
              "audio_start_id",
              "audio_end_id",
+             "audio_token_id",
          ]
          for arg in optional_args:
              if arg in obj:
@@ -434,6 +466,7 @@ class Req:
          self.sampling_params = sampling_params
          self.custom_logit_processor = custom_logit_processor
          self.return_hidden_states = return_hidden_states
+         self.lora_path = lora_path

          # Memory pool info
          self.req_pool_idx: Optional[int] = None
@@ -441,11 +474,13 @@ class Req:
          # Check finish
          self.tokenizer = None
          self.finished_reason = None
+         # Whether this request has finished output
+         self.finished_output = None
          # If we want to abort the request in the middle of the event loop, set this to true
          # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
          self.to_abort = False
          # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-         self.to_abort_message: str = "Unknown error"
+         self.to_abort_message: str = None
          self.stream = stream
          self.eos_token_ids = eos_token_ids

@@ -483,6 +518,13 @@ class Req:
          # For retraction
          self.is_retracted = False

+         # Incremental streamining
+         self.send_token_offset: int = 0
+         self.send_decode_id_offset: int = 0
+         # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+         # because the decode server does not have the first output token logprobs
+         self.send_output_token_logprobs_offset: int = 0
+
          # Logprobs (arguments)
          self.return_logprob = return_logprob
          # Start index to compute logprob from.
@@ -492,11 +534,9 @@ class Req:
          self.temp_scaled_logprobs = False
          self.top_p_normalized_logprobs = False

-         # Latency Breakdown
-         self.queue_time_start = None
-         self.queue_time_end = None
-
          # Logprobs (return values)
+         # True means the input logprob has been already sent to detokenizer.
+         self.input_logprob_sent: bool = False
          self.input_token_logprobs_val: Optional[List[float]] = None
          self.input_token_logprobs_idx: Optional[List[int]] = None
          self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +551,10 @@ class Req:
          self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

          if return_logprob:
+             # shape: (bs, 1)
              self.output_token_logprobs_val = []
              self.output_token_logprobs_idx = []
+             # shape: (bs, k)
              self.output_top_logprobs_val = []
              self.output_top_logprobs_idx = []
              self.output_token_ids_logprobs_val = []
@@ -530,6 +572,7 @@ class Req:

          # Constrained decoding
          self.grammar: Optional[BaseGrammarObject] = None
+         self.grammar_wait_ct = 0

          # The number of cached tokens that were already cached in the KV cache
          self.cached_tokens = 0
@@ -538,7 +581,12 @@ class Req:
          # The number of verification forward passes in the speculative decoding.
          # This is used to compute the average acceptance length per request.
          self.spec_verify_ct = 0
-         self.lora_path = lora_path
+
+         # For metrics
+         self.time_stats: TimeStats = TimeStats()
+         self.has_log_time_stats: bool = False
+         self.queue_time_start = None
+         self.queue_time_end = None

          # For disaggregation
          self.bootstrap_host: str = bootstrap_host
@@ -546,8 +594,6 @@ class Req:
          self.bootstrap_room: Optional[int] = bootstrap_room
          self.disagg_kv_sender: Optional[BaseKVSender] = None

-         # used for warmup because we don't have a pair yet when init
-         self.skip_kv_transfer: bool = False
          # the start index of the sent kv cache
          # We want to send it chunk by chunk for chunked prefill.
          # After every chunk forward, we do the following:
@@ -555,14 +601,11 @@ class Req:
          #     start_send_idx = len(req.fill_ids)
          self.start_send_idx: int = 0

-         self.metadata_buffer_index: int = -1
-         # The first output_id transferred from prefill instance.
-         self.transferred_output_id: Optional[int] = None
-
          # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
          # This is because kv is not ready in `process_prefill_chunk`.
          # We use `tmp_end_idx` to store the end index of the kv cache to send.
          self.tmp_end_idx: int = -1
+         self.metadata_buffer_index: int = -1

      @property
      def seqlen(self):
@@ -653,6 +696,11 @@ class Req:
              )
              return

+         if self.grammar is not None:
+             if self.grammar.is_terminated():
+                 self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                 return
+
          last_token_id = self.output_ids[-1]

          if not self.sampling_params.ignore_eos:
@@ -697,13 +745,41 @@ class Req:
          self.req_pool_idx = None
          self.already_computed = 0

+     def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+         token_indices = req_to_token_pool.req_to_token[
+             self.req_pool_idx, : self.seqlen - 1
+         ]
+         self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+     def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+         token_indices = req_to_token_pool.req_to_token[
+             self.req_pool_idx, : self.seqlen - 1
+         ]
+         token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+         del self.kv_cache_cpu
+
+     def log_time_stats(self):
+         # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+         if self.has_log_time_stats is True:
+             return
+
+         if self.bootstrap_room is not None:
+             prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+         else:
+             prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+         logger.info(f"{prefix}: {self.time_stats}")
+         self.has_log_time_stats = True
+
      def __repr__(self):
          return (
              f"Req(rid={self.rid}, "
-             f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+             f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+             f"{self.grammar=}, "
+             f"{self.sampling_params=})"
          )


+ # Batch id
  bid = 0

@@ -862,7 +938,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          error_msg = (
              f"{phase_str} out of memory. Try to lower your batch size.\n"
              f"Try to allocate {num_tokens} tokens.\n"
-             f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+             f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
          )
          logger.error(error_msg)
          if self.tree_cache is not None:
@@ -903,7 +979,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          error_msg = (
              f"Prefill out of memory. Try to lower your batch size.\n"
              f"Try to allocate {extend_num_tokens} tokens.\n"
-             f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+             f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
              f"{self.token_to_kv_pool_allocator.available_size()=}\n"
              f"{self.tree_cache.evictable_size()=}\n"
          )
@@ -938,7 +1014,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          error_msg = (
              f"Decode out of memory. Try to lower your batch size.\n"
              f"Try to allocate {len(seq_lens)} tokens.\n"
-             f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+             f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
              f"{self.token_to_kv_pool_allocator.available_size()=}\n"
              f"{self.tree_cache.evictable_size()=}\n"
          )
@@ -1019,7 +1095,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          else:
              self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

-         assert len(self.out_cache_loc) == self.extend_num_tokens
+         assert (
+             len(self.out_cache_loc) == self.extend_num_tokens
+         ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

      def prepare_for_extend(self):
          self.forward_mode = ForwardMode.EXTEND
@@ -1447,7 +1525,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              i
              for i in range(len(self.reqs))
              if not self.reqs[i].finished()
-             and not self.reqs[i] in chunked_req_to_exclude
+             and self.reqs[i] not in chunked_req_to_exclude
          ]

          if keep_indices is None or len(keep_indices) == 0:
@@ -1468,7 +1546,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

          self.reqs = [self.reqs[i] for i in keep_indices]
-         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
+         if self.multimodal_inputs is not None:
+             self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
          self.req_pool_indices = self.req_pool_indices[keep_indices_device]
          self.seq_lens = self.seq_lens[keep_indices_device]
          self.out_cache_loc = None
@@ -1517,7 +1596,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
          self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
          self.reqs.extend(other.reqs)
-         self.multimodal_inputs.extend(other.multimodal_inputs)
+         if self.multimodal_inputs is not None:
+             self.multimodal_inputs.extend(other.multimodal_inputs)

          self.return_logprob |= other.return_logprob
          self.has_stream |= other.has_stream
sglang/srt/managers/schedule_policy.py
@@ -22,11 +22,7 @@ from typing import Dict, List, Optional, Set, Union

  import torch

- from sglang.srt.managers.schedule_batch import (
-     Req,
-     ScheduleBatch,
-     global_server_args_dict,
- )
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode