sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/minicpm.py (new file)
@@ -0,0 +1,167 @@
+ import asyncio
+ from typing import List, Union
+
+ import torch
+
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
+     get_global_processor,
+ )
+ from sglang.srt.models.minicpmo import MiniCPMO
+ from sglang.srt.models.minicpmv import MiniCPMV
+
+
+ # Compatible with both 'O' and 'V'
+ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
+     models = [MiniCPMV, MiniCPMO]
+
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.image_token = "(<image>./</image>)"
+         self.audio_token = "(<audio>./</audio>)"
+
+     @staticmethod
+     def _process_data_task(input_text, images=None, audios=None):
+
+         if isinstance(images, list) and len(images) == 0:
+             images = None
+         if isinstance(audios, list) and len(audios) == 0:
+             audios = None
+         result = get_global_processor().__call__(
+             text=input_text,
+             images=images,
+             audios=audios,
+             return_tensors="pt",
+             chunk_input=True,
+         )
+         return {
+             "input_ids": result.input_ids,
+             "pixel_values": getattr(result, "pixel_values", None),
+             "tgt_sizes": getattr(result, "tgt_sizes", None),
+             "audio_features": getattr(result, "audio_features", None),
+             "audio_feature_lens": getattr(result, "audio_feature_lens", None),
+             "audio_bounds": getattr(result, "audio_bounds", None),
+         }
+
+     async def _process_data(self, images, input_text, audios=None):
+         if self.executor is not None:
+             loop = asyncio.get_event_loop()
+             multimodal_data_inputs = await loop.run_in_executor(
+                 self.executor,
+                 MiniCPMMultimodalProcessor._process_data_task,
+                 input_text,
+                 images,
+                 audios,
+             )
+         else:
+             multimodal_data_inputs = self._processor(
+                 images=images, text=input_text, audios=audios, return_tensors="pt"
+             )
+
+         return multimodal_data_inputs
+
+     async def process_mm_data_async(
+         self,
+         image_data: List[Union[str, bytes]],
+         input_ids,
+         request_obj,
+         max_req_input_len,
+     ):
+         audio_data = request_obj.audio_data
+         if not image_data and not audio_data:
+             return None
+         if not isinstance(image_data, list):
+             image_data = [image_data]
+         if not isinstance(audio_data, list):
+             audio_data = [audio_data]
+
+         base_output = self.load_mm_data(
+             input_ids=input_ids,
+             max_req_input_len=max_req_input_len,
+             audio_data=audio_data,
+             image_data=image_data,
+             multimodal_tokens=MultimodalSpecialTokens(
+                 image_token=self.image_token, audio_token=self.audio_token
+             ),
+         )
+         if base_output is None:
+             return None
+
+         res = await self._process_data(
+             images=base_output.images,
+             input_text=base_output.input_text,
+             audios=base_output.audios,
+         )
+
+         # Collect special token ids
+         tokenizer = self._processor.tokenizer
+         slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
+             None,
+             None,
+             None,
+             None,
+         )
+         if tokenizer.slice_start_id:
+             slice_start_id = tokenizer.slice_start_id
+             slice_end_id = tokenizer.slice_end_id
+         if hasattr(tokenizer, "audio_start_id"):
+             audio_start_id = tokenizer.audio_start_id
+             audio_end_id = tokenizer.audio_end_id
+
+         im_token_id = tokenizer.unk_token_id
+         pixel_values = res["pixel_values"]
+         tgt_sizes = res["tgt_sizes"]
+
+         if not isinstance(pixel_values, (torch.Tensor, list)):
+             raise ValueError(
+                 "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
+             )
+
+         if not isinstance(tgt_sizes, (torch.Tensor, list)):
+             raise ValueError(
+                 "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
+             )
+
+         if len(pixel_values) != len(tgt_sizes):
+             raise ValueError(
+                 "Inconsistent batch lengths, found: "
+                 f"{len(pixel_values)} vs. {len(tgt_sizes)}"
+             )
+
+         pixel_values_flat: List[torch.Tensor] = []
+         tgt_sizes_flat: List[torch.Tensor] = []
+         for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+             # per image
+             if len(pixel_b) != len(tgt_b):
+                 raise ValueError(
+                     "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
+                 )
+             for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                 pixel_values_flat += [pixel_n]
+                 tgt_sizes_flat += [tgt_n]
+
+         pixel_values = pixel_values_flat
+         if len(tgt_sizes_flat) == 0:
+             tgt_sizes = None
+         else:
+             tgt_sizes = torch.stack(tgt_sizes_flat)
+         if not isinstance(res["audio_features"], list):
+             res["audio_features"] = [res["audio_features"]]
+         return {
+             "input_ids": res["input_ids"].flatten().tolist(),
+             "pixel_values": pixel_values,
+             "tgt_sizes": tgt_sizes,
+             "data_hashes": base_output.mm_data_hashes,
+             "modalities": request_obj.modalities or ["image"],
+             "audio_start_id": audio_start_id,
+             "audio_end_id": audio_end_id,
+             "audio_features": res["audio_features"],
+             "audio_bounds": res["audio_bounds"],
+             "audio_feature_lens": res["audio_feature_lens"],
+             "im_token_id": im_token_id,
+             "im_start_id": tokenizer.im_start_id,
+             "im_end_id": tokenizer.im_end_id,
+             "slice_start_id": slice_start_id,
+             "slice_end_id": slice_end_id,
+         }
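Across this release, each processor declares the model classes it serves through a `models` class attribute (as in `MiniCPMMultimodalProcessor` above), replacing the per-module `ImageProcessorMapping` dicts deleted in the hunks below. A minimal sketch of the dispatch this attribute implies; `PROCESSOR_REGISTRY`, `register_processor`, and `get_processor_cls` are hypothetical names, since the real lookup lives in the new `sglang/srt/managers/multimodal_processor.py`, which this view does not expand:

```python
# Hypothetical sketch of registration keyed on the `models` class attribute.
PROCESSOR_REGISTRY: dict = {}

def register_processor(processor_cls):
    # Map every model class the processor declares to the processor itself.
    for model_cls in processor_cls.models:
        PROCESSOR_REGISTRY[model_cls] = processor_cls
    return processor_cls

def get_processor_cls(model_cls):
    # One central lookup replaces the old per-module ImageProcessorMapping dicts.
    return PROCESSOR_REGISTRY[model_cls]
```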
sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py
@@ -1,15 +1,17 @@
  import asyncio
  from typing import List, Union

- from sglang.srt.managers.image_processor import BaseImageProcessor
- from sglang.srt.managers.image_processors.base_image_processor import (
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
      get_global_processor,
  )
  from sglang.srt.models.mllama import MllamaForConditionalGeneration
  from sglang.srt.utils import load_image


- class MllamaImageProcessor(BaseImageProcessor):
+ class MllamaImageProcessor(BaseMultimodalProcessor):
+     models = [MllamaForConditionalGeneration]
+
      def __init__(self, hf_config, server_args, _processor):
          super().__init__(hf_config, server_args, _processor)

@@ -32,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):

          return image_inputs

-     async def process_images_async(
+     async def process_mm_data_async(
          self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
      ):
          if not image_data:
@@ -51,10 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
          images = load_image(image_data[0])[0]

          image_inputs = await self._process_single_image(images, input_text)
-         image_inputs["image_hashes"] = [hash(str(image_data))]
+         image_inputs["data_hashes"] = [hash(str(image_data))]
          image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

          return image_inputs
-
-
- ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py
@@ -1,11 +1,16 @@
  import asyncio
  import math
+ import time
  from typing import List, Union

+ import torch
  from PIL import Image

- from sglang.srt.managers.image_processor import BaseImageProcessor
- from sglang.srt.managers.image_processors.base_image_processor import (
+ from sglang.srt.managers.multimodal_processor import (
+     BaseMultimodalProcessor as SGLangBaseProcessor,
+ )
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     MultimodalSpecialTokens,
      get_global_processor,
  )
  from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -13,7 +18,9 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration


  # Compatible with Qwen2VL and Qwen2_5VL
- class Qwen2_5VLImageProcessor(BaseImageProcessor):
+ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
+     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+
      def __init__(self, hf_config, server_args, _processor):
          super().__init__(hf_config, server_args, _processor)
          self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
@@ -25,7 +32,6 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
          self.IMAGE_FACTOR = 28
          self.MIN_PIXELS = 4 * 28 * 28
          self.MAX_PIXELS = 16384 * 28 * 28
-         self.MAX_PIXELS = 16384 * 28 * 28
          self.MAX_RATIO = 200

      @staticmethod
@@ -44,7 +50,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
              "video_grid_thws": getattr(result, "video_grid_thws", None),
          }

-     async def _process_images(self, images, input_text) -> dict:
+     async def _process_single_image(self, images, input_text) -> dict:
          if self.executor is not None:
              loop = asyncio.get_event_loop()
              return await loop.run_in_executor(
@@ -57,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
          else:
              return self._process_images_task(images, input_text, self.hf_config)

-     async def process_images_async(
+     async def process_mm_data_async(
          self,
          image_data: List[Union[str, bytes]],
          input_ids,
@@ -66,17 +72,18 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
          *args,
          **kwargs,
      ):
+         start = time.time()
          if not image_data:
              return None
          if isinstance(image_data, str):
              image_data = [image_data]

          image_token = self.IMAGE_TOKEN
-         base_output = self.load_images(
-             input_ids,
-             image_data,
-             image_token,
-             max_req_input_len,
+         base_output = self.load_mm_data(
+             input_ids=input_ids,
+             image_data=image_data,
+             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+             max_req_input_len=max_req_input_len,
          )

          def smart_resize(
@@ -137,25 +144,24 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
              """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
              return math.floor(number / factor) * factor

-         images = [resize_image(image) for image in base_output.all_frames]
+         images = [resize_image(image) for image in base_output.images]

-         ret = await self._process_images(images, base_output.input_text)
+         ret = await self._process_single_image(
+             images=images, input_text=base_output.input_text
+         )
+
+         image_grid_thws = torch.concat([ret["image_grid_thw"]])
+         video_grid_thws = None
          return {
              "input_ids": ret["input_ids"].flatten().tolist(),
              "pixel_values": ret["pixel_values"],
-             "image_hashes": base_output.image_hashes,
+             "data_hashes": base_output.mm_data_hashes,
              "modalities": request_obj.modalities or ["image"],
-             "image_grid_thws": ret["image_grid_thw"],
-             "video_grid_thws": ret["video_grid_thws"],
+             "image_grid_thws": image_grid_thws,
+             "video_grid_thws": video_grid_thws,
              "im_start_id": self.IM_START_TOKEN_ID,
              "im_end_id": self.IM_END_TOKEN_ID,
              "im_token_id": self.image_token_id,
              "video_token_id": self.video_token_id,
              "second_per_grid_ts": ret["second_per_grid_ts"],
          }
-
-
- ImageProcessorMapping = {
-     Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
-     Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
- }
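The constants kept above (`IMAGE_FACTOR = 28`, `MIN_PIXELS`, `MAX_PIXELS`, `MAX_RATIO = 200`) feed the nested `smart_resize` helper whose body is collapsed out of this hunk; only the `floor_by_factor` docstring survives. A sketch of the standard Qwen2-VL resizing rule these constants imply, reconstructed from that docstring and the well-known Qwen2-VL preprocessing; treat it as illustrative rather than the exact shipped body:

```python
import math

def round_by_factor(number, factor: int) -> int:
    """Closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor

def ceil_by_factor(number, factor: int) -> int:
    """Smallest integer >= 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor

def floor_by_factor(number, factor: int) -> int:
    """Largest integer <= 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor

def smart_resize(height, width, factor=28, min_pixels=4 * 28 * 28,
                 max_pixels=16384 * 28 * 28, max_ratio=200):
    # Keep both sides divisible by `factor` and total pixels within bounds.
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError("absolute aspect ratio exceeds max_ratio")
    h = max(factor, round_by_factor(height, factor))
    w = max(factor, round_by_factor(width, factor))
    if h * w > max_pixels:
        # Too large: scale down, rounding each side down to the factor grid.
        beta = math.sqrt((height * width) / max_pixels)
        h = floor_by_factor(height / beta, factor)
        w = floor_by_factor(width / beta, factor)
    elif h * w < min_pixels:
        # Too small: scale up, rounding each side up to the factor grid.
        beta = math.sqrt(min_pixels / (height * width))
        h = ceil_by_factor(height * beta, factor)
        w = ceil_by_factor(width * beta, factor)
    return h, w
```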
sglang/srt/managers/schedule_batch.py
@@ -42,6 +42,8 @@ import triton.language as tl
  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+ from sglang.srt.disaggregation.conn import KVSender
+ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -49,7 +51,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import get_compiler_backend, next_power_of_2
+ from sglang.srt.utils import get_compiler_backend

  if TYPE_CHECKING:
      from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -67,12 +69,15 @@ global_server_args_dict = {
      "enable_nan_detection": ServerArgs.enable_nan_detection,
      "enable_dp_attention": ServerArgs.enable_dp_attention,
      "enable_ep_moe": ServerArgs.enable_ep_moe,
+     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
      "device": ServerArgs.device,
      "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
      "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
      "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+     "enable_flashmla": ServerArgs.enable_flashmla,
      "disable_radix_cache": ServerArgs.disable_radix_cache,
      "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
  }

  logger = logging.getLogger(__name__)
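The three new entries mirror `ServerArgs` fields into `global_server_args_dict`, so scheduler-side code can gate on them with plain dict lookups rather than threading the args object through; the FlashMLA branch added to `get_model_worker_batch` later in this file is one such consumer. A minimal sketch of that lookup pattern:

```python
# Sketch: runtime feature gates are dict lookups on global_server_args_dict,
# mirroring the condition added to get_model_worker_batch below.
use_mla_decode_path = (
    global_server_args_dict["enable_flashinfer_mla"]
    or global_server_args_dict["enable_flashmla"]
)
```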
@@ -139,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):


  @dataclasses.dataclass
- class ImageInputs:
+ class MultimodalInputs:
      """The image related inputs."""

      pixel_values: Union[torch.Tensor, np.array]
-     image_hashes: Optional[list] = None
+     data_hashes: Optional[list] = None
      image_sizes: Optional[list] = None
      image_offsets: Optional[list] = None
      image_pad_len: Optional[list] = None
@@ -156,34 +161,48 @@ class ImageInputs:
      aspect_ratio_mask: Optional[List[torch.Tensor]] = None

      # QWen2-VL related
-     image_grid_thws: List[Tuple[int, int, int]] = None
+     # [num_of_images, t, h, w]
+     image_grid_thws: torch.Tensor = None
      mrope_position_delta: Optional[torch.Tensor] = None
+     # Qwen2-VL video related
+     video_token_id: Optional[int] = None
+     video_grid_thws: List[Tuple[int, int, int]] = None
+     second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+     # deepseek vl2 related
+     images_emb_mask: Optional[List[torch.Tensor]] = None
+     image_spatial_crop: Optional[List[torch.Tensor]] = None

      # The id of the single-image placeholder token
      im_token_id: Optional[torch.Tensor] = None
+
      # All the images in the batch should share the same special image
      # bound token ids.
      im_start_id: Optional[int] = None
      im_end_id: Optional[int] = None
      slice_start_id: Optional[int] = None
      slice_end_id: Optional[int] = None
+     # [num_images, 2 (w, h)]
      tgt_sizes: Optional[list] = None

-     # denotes the number of valid image tokens in each image
-     images_emb_mask: Optional[torch.BoolTensor] = None
+     # audio
+     audio_start_id: Optional[torch.Tensor] = None
+     audio_end_id: Optional[torch.Tensor] = None
+     audio_features: Optional[List[torch.Tensor]] = None
+     audio_feature_lens: Optional[List[torch.Tensor]] = None

      @staticmethod
      def from_dict(obj: dict):
-         ret = ImageInputs(
+         ret = MultimodalInputs(
              pixel_values=obj["pixel_values"],
-             image_hashes=obj["image_hashes"],
+             data_hashes=obj["data_hashes"],
          )

          # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
          # Please note that if the `input_ids` is later used in the model forward,
          # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
          # errors in cuda kernels. See also llava.py for example.
-         ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
+         ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]

          optional_args = [
              "image_sizes",
@@ -191,43 +210,104 @@ class ImageInputs:
              "aspect_ratio_ids",
              "aspect_ratio_mask",
              "image_grid_thws",
+             "images_emb_mask",
+             "image_spatial_crop",
              "im_token_id",
              "im_start_id",
              "im_end_id",
              "slice_start_id",
              "slice_end_id",
              "tgt_sizes",
-             "images_emb_mask",
+             "audio_start_id",
+             "audio_end_id",
+             "audio_features",
+             "audio_feature_lens",
          ]
          for arg in optional_args:
              if arg in obj:
                  setattr(ret, arg, obj[arg])

+         # validate
+         assert (
+             isinstance(ret.pixel_values, torch.Tensor)
+             or isinstance(ret.pixel_values, np.ndarray)
+             or isinstance(ret.pixel_values, list)
+         )
+
+         assert ret.audio_features is None or isinstance(ret.audio_features, list)
+
          return ret

-     def merge(self, other):
-         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+     def contains_image_inputs(self) -> bool:
+         """ """
+         return self.pixel_values is not None and self.pixel_values != []
+
+     def contains_audio_inputs(self) -> bool:
+         """ """
+         return self.audio_features is not None and self.audio_features != []
+
+     def merge(self, other: MultimodalInputs):
+         """
+         merge image inputs when requests are being merged
+         """
+         if isinstance(self.pixel_values, list):
+             # in some rare cases, pixel values are list of patches with different shapes
+             # e.g. minicpm
+             self.pixel_values += other.pixel_values
+         else:
+             assert (
+                 self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+             ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
+             self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+
+         # args would be stacked along first dim
+         # usually these are already tensors
+         stack_args = [
+             # TODO: merge with image_grid_thws, basically the same thing
+             "tgt_sizes",
+             "image_spatial_crop",
+         ]
+         for arg in stack_args:
+             if getattr(self, arg, None) is None:
+                 setattr(self, arg, getattr(other, arg, None))
+             elif getattr(other, arg, None) is not None:
+                 # self and other both not None
+                 setattr(
+                     self,
+                     arg,
+                     torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
+                 )
+
+         if self.image_grid_thws is None:
+             self.image_grid_thws = other.image_grid_thws
+         elif other.image_grid_thws is not None:
+             self.image_grid_thws = torch.concat(
+                 [self.image_grid_thws, other.image_grid_thws]
+             )

          # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
          # Please note that if the `input_ids` is later used in the model forward,
          # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
          # errors in cuda kernels. See also llava.py for example.
-         self.image_hashes += other.image_hashes
-         self.pad_values = [x % (1 << 30) for x in self.image_hashes]
+         self.data_hashes += other.data_hashes
+         self.pad_values = [x % (1 << 30) for x in self.data_hashes]

+         # args needed to be merged
          optional_args = [
+             "audio_features",
              "image_sizes",
              "image_offsets",
              "image_pad_len",
              # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
              "aspect_ratio_ids",
              "aspect_ratio_mask",
-             "image_grid_thws",
+             "images_emb_mask",
          ]
          for arg in optional_args:
-             if getattr(self, arg, None) is not None:
-                 setattr(self, arg, getattr(self, arg) + getattr(other, arg))
+             self_arg = getattr(self, arg, None)
+             if self_arg is not None:
+                 setattr(self, arg, self_arg + getattr(other, arg))
+         # other args would be kept intact


  class Req:
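The `pad_values` comments above describe the core caching trick: image content is hashed, and the hash, clamped into token-id range, stands in for real token ids so the radix cache can prefix-match requests containing the same image. A minimal illustration with hypothetical values:

```python
# Hypothetical illustration of the pad_values derivation used by
# MultimodalInputs.from_dict and merge above.
data_hashes = [hash("image-bytes-a"), hash("image-bytes-b")]

# Clamp into [0, 2**30) so the fake token ids stay bounded; as the comments
# note, callers must still clamp to [0, vocab_size) before a model forward.
pad_values = [x % (1 << 30) for x in data_hashes]
```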
@@ -305,7 +385,7 @@ class Req:
          self.decoded_text = ""

          # For multimodal inputs
-         self.image_inputs: Optional[ImageInputs] = None
+         self.multimodal_inputs: Optional[MultimodalInputs] = None

          # Prefix info
          # The indices to kv cache for the shared prefix.
@@ -378,15 +458,33 @@ class Req:
          self.spec_verify_ct = 0
          self.lora_path = lora_path

+         # For disaggregation
+         self.bootstrap_host: str = "0.0.0.0"
+         self.bootstrap_room: Optional[int] = None
+         self.disagg_kv_sender: Optional[KVSender] = None
+
+         # used for warmup because we don't have a pair yet when init
+         self.skip_kv_transfer: bool = False
+         # the start index of the sent kv cache
+         # We want to send it chunk by chunk for chunked prefill.
+         # After every chunk forward, we do the following:
+         # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
+         # start_send_idx = len(req.fill_ids)
+         self.start_send_idx: int = 0
+
+         self.metadata_buffer_index: int = -1
+         # The first output_id transferred from prefill instance.
+         self.transferred_output_id: Optional[int] = None
+
      @property
      def seqlen(self):
          return len(self.origin_input_ids) + len(self.output_ids)

      def extend_image_inputs(self, image_inputs):
-         if self.image_inputs is None:
-             self.image_inputs = image_inputs
+         if self.multimodal_inputs is None:
+             self.multimodal_inputs = image_inputs
          else:
-             self.image_inputs.merge(image_inputs)
+             self.multimodal_inputs.merge(image_inputs)

      def finished(self) -> bool:
          # Whether request reached finished condition
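The `start_send_idx` comment above already gives the chunked KV transfer as pseudocode; a sketch of the driver loop it implies, with `kv_send` as a hypothetical stand-in for the transfer call on the new `disagg_kv_sender: KVSender` field:

```python
# Sketch of the chunk-by-chunk KV send described in the Req comments above.
def send_kv_chunk(req, kv_send):
    # Send only the tokens filled in since the previous chunk forward...
    kv_send(req.input_ids[req.start_send_idx : len(req.fill_ids)])
    # ...then advance the watermark so the next chunk starts here.
    req.start_send_idx = len(req.fill_ids)
```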
@@ -513,7 +611,7 @@ bid = 0


  @dataclasses.dataclass
- class ScheduleBatch:
+ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      """Store all information of a batch on the scheduler."""

      # Request, memory pool, and cache
@@ -727,7 +825,7 @@ class ScheduleBatch:
          self.encoder_cached = []

          for req in self.reqs:
-             im = req.image_inputs
+             im = req.multimodal_inputs
              if im is None or im.num_image_tokens is None:
                  # No image input
                  self.encoder_lens_cpu.append(0)
@@ -840,6 +938,8 @@ class ScheduleBatch:
              # If req.input_embeds is already a list, append its content directly
              input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

+             if req.is_retracted:
+                 req.already_computed = 0
              req.cached_tokens += pre_len - req.already_computed
              req.already_computed = seq_len
              req.is_retracted = False
@@ -1244,14 +1344,14 @@ class ScheduleBatch:
          self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
          self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

-         self.req_pool_indices = torch.concat(
+         self.req_pool_indices = torch.cat(
              [self.req_pool_indices, other.req_pool_indices]
          )
-         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
+         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
          self.out_cache_loc = None
          self.seq_lens_sum += other.seq_lens_sum
          if self.output_ids is not None:
-             self.output_ids = torch.concat([self.output_ids, other.output_ids])
+             self.output_ids = torch.cat([self.output_ids, other.output_ids])
          if self.return_logprob and other.return_logprob:
              self.top_logprobs_nums.extend(other.top_logprobs_nums)
              self.token_ids_logprobs.extend(other.token_ids_logprobs)
@@ -1273,7 +1373,10 @@

      def get_model_worker_batch(self) -> ModelWorkerBatch:
          if self.forward_mode.is_decode_or_idle():
-             if global_server_args_dict["enable_flashinfer_mla"]:
+             if (
+                 global_server_args_dict["enable_flashinfer_mla"]
+                 or global_server_args_dict["enable_flashmla"]
+             ):
                  decode_seq_lens = self.seq_lens.cpu()
              else:
                  decode_seq_lens = None
@@ -1311,7 +1414,7 @@
              extend_seq_lens=extend_seq_lens,
              extend_prefix_lens=extend_prefix_lens,
              extend_logprob_start_lens=extend_logprob_start_lens,
-             image_inputs=[r.image_inputs for r in self.reqs],
+             multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
              encoder_cached=self.encoder_cached,
              encoder_lens=self.encoder_lens,
              encoder_lens_cpu=self.encoder_lens_cpu,
@@ -1394,7 +1497,7 @@ class ModelWorkerBatch:
      extend_input_logprob_token_ids: Optional[torch.Tensor]

      # For multimodal
-     image_inputs: Optional[List[ImageInputs]]
+     multimodal_inputs: Optional[List[MultimodalInputs]]

      # For encoder-decoder
      encoder_cached: Optional[List[bool]]
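For code built on top of these structures, the rename in this file is mechanical; a before/after sketch with a hypothetical caller holding a `ScheduleBatch` as `batch`:

```python
# Hypothetical caller-side rename implied by this file's changes:
mm_inputs = [r.multimodal_inputs for r in batch.reqs]  # 0.4.4.post2
# previously: [r.image_inputs for r in batch.reqs]     # 0.4.4.post1
```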