sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/base_processor.py

@@ -4,14 +4,14 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
+from transformers import BaseImageProcessorFast

-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image


 @dataclasses.dataclass
@@ -78,6 +78,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios

         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
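For context, the new branch above only moves preprocessing to the GPU when the underlying image processor is one of transformers' fast (torch-backed) implementations. A minimal sketch of the same check outside SGLang; the model id is a placeholder for any checkpoint that ships a fast image processor, not something taken from this diff:

```python
# Sketch only: detect a fast (torch-backed) image processor and ask it to run on GPU.
from transformers import AutoProcessor, BaseImageProcessorFast

processor = AutoProcessor.from_pretrained("some/model-with-fast-image-processor")

kwargs = {}
if hasattr(processor, "image_processor") and isinstance(
    processor.image_processor, BaseImageProcessorFast
):
    # Fast image processors accept a device argument and can resize/normalize on GPU.
    kwargs["device"] = "cuda"
# kwargs is then forwarded to processor(...) together with text/images.
```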
@@ -96,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
@@ -111,6 +118,84 @@ class BaseMultimodalProcessor(ABC):

         return estimated_frames_list

+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        load multimodal data parallelly
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
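The frame budget in the new `submit_data_loading_tasks` is a simple proportional scale-down: with `MAX_NUM_FRAMES = 30`, each input keeps roughly `estimated_frames * 30 / total_frame_count` frames, never fewer than one. A small worked example with made-up numbers, just to illustrate the arithmetic:

```python
# Worked example of the scaling heuristic above (input values are invented).
MAX_NUM_FRAMES = 30
estimated_frames_list = [20, 20, 20]            # three videos, ~20 frames each
total_frame_count = sum(estimated_frames_list)  # 60
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))  # 0.5
frame_limits = [max(1, int(f * scaling_factor)) for f in estimated_frames_list]
print(frame_limits)  # [10, 10, 10] -> 30 frames total, within the budget
```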
@@ -155,84 +240,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)

-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-                        assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+                # TODO: handle video
+            else:
+                new_text += text_part

         out = BaseMultiModalProcessorOutput(
             images=images,
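In the rewritten loop above, `futures[i]` and `task_info[i]` always describe the same submitted item, so results can be consumed in prompt order with a single pointer. A stripped-down sketch of that pattern, with a hypothetical `load()` helper standing in for `_load_single_item`:

```python
# Sketch of the futures/task_info pairing used above; load() is a placeholder.
from concurrent.futures import ThreadPoolExecutor

def load(data):
    return data.upper()  # stands in for real image/audio decoding

executor = ThreadPoolExecutor(max_workers=4)
items = ["img-a", "img-b", "audio-c"]

futures, task_info = [], []
for item in items:
    futures.append(executor.submit(load, item))
    task_info.append(("image" if item.startswith("img") else "audio", item))

# Consume in submission order; .result() blocks until that item is ready.
for (kind, original), future in zip(task_info, futures):
    print(kind, original, "->", future.result())
```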
sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )

sglang/srt/managers/multimodal_processors/mllama4.py

@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union

 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image


 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor

         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer

@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]

             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)

             if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)

                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                     for per_image_input_ids in split_input_ids:
                         embed_is_patch.append(per_image_input_ids == patch_id)

-                    image_inputs["embed_is_patch"] = embed_is_patch
+                    processor_output["embed_is_patch"] = embed_is_patch

             # Convert to the format expected by SGLang
-            image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+            processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+            processor_output["im_start_id"] = self.boi_token_index
+            processor_output["im_end_id"] = self.eoi_token_index
+            processor_output["im_token_id"] = self.image_token_index

             # Add metadata for image processing
-            image_inputs["mm_items"] = [
+            processor_output["mm_items"] = [
                 MultimodalDataItem(
-                    pixel_values=image_inputs["pixel_values"],
+                    pixel_values=processor_output["pixel_values"],
                     modality=Modality.IMAGE,
-                    # Add additional metadata needed for Llama4 vision processing
-                    embed_is_patch=image_inputs.get("embed_is_patch", None),
-                    aspect_ratios=image_inputs.get("aspect_ratios", None),
-                    patches_per_image=image_inputs.get("patches_per_image", None),
                 )
             ]

-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
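The `im_start_id` / `im_end_id` / `im_token_id` values written into `processor_output` above come straight from the HuggingFace Llama 4 config fields captured in `__init__` earlier in this diff. A hedged sketch of the same lookup; the model id is a placeholder, not a reference taken from this release:

```python
# Sketch only: reads the same config fields the diff captures in __init__.
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("some-org/llama-4-checkpoint")
boi_token_index = hf_config.boi_token_index      # begin-of-image token id
eoi_token_index = hf_config.eoi_token_index      # end-of-image token id
image_token_index = hf_config.image_token_index  # per-patch image placeholder id
```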
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import hashlib
 from enum import Enum, auto

 # Copyright 2023-2024 SGLang Team
@@ -44,7 +45,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.disaggregation.conn import KVSender
+from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -66,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
@@ -76,12 +76,12 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }

 logger = logging.getLogger(__name__)
@@ -157,7 +157,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """

     modality: Modality
@@ -195,17 +195,54 @@ class MultimodalDataItem:

     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
         """

+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            assert isinstance(tensor, torch.Tensor)
+            if tensor.is_cuda:
+                # TODO: improve this
+                tensor_cpu = tensor.cpu()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())
+
         def hash_feature(f):
             if isinstance(f, list):
-                return hash(tuple(flatten_nested_list(f)))
+                if isinstance(f[0], torch.Tensor):
+                    return tensor_hash(f)
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)

         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
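The switch from Python's built-in `hash()` to a truncated SHA-256 digest makes the pad value content-addressed and stable across processes (built-in `hash()` of bytes and strings is randomized per interpreter run). A small standalone check of the same idea, assuming only `torch` and `hashlib`:

```python
import hashlib
import torch

def data_hash(data) -> int:
    # First 8 bytes of SHA-256, interpreted as an unsigned 64-bit int.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False)

def tensor_hash(t: torch.Tensor) -> int:
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # numpy/memoryview cannot view bfloat16 storage
    return data_hash(memoryview(t.cpu().numpy()).tobytes())

a = torch.arange(16, dtype=torch.bfloat16)
b = torch.arange(16, dtype=torch.bfloat16)
assert tensor_hash(a) == tensor_hash(b)  # equal contents -> equal hash, in any process
```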
@@ -230,6 +267,9 @@ class MultimodalDataItem:
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO
@@ -248,7 +288,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -268,11 +308,7 @@ class MultimodalInputs:
         )

         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

         assert len(ret.mm_items) != 0

@@ -284,7 +320,6 @@ class MultimodalInputs:
             item.set_pad_value()

         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -307,8 +342,8 @@ class MultimodalInputs:
         """ """
         return any(item.is_audio() for item in self.mm_items)

-    def collect_image_inputs(self) -> List[torch.Tensor]:
-        return [item.pixel_values for item in self.mm_items if item.is_image()]
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())

     def merge(self, other: MultimodalInputs):
         """
@@ -322,10 +357,8 @@ class MultimodalInputs:

         # args needed to be merged
         optional_args = [
-            "items",
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -354,6 +387,8 @@ class Req:
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
+        bootstrap_host: Optional[str] = None,
+        bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -438,6 +473,10 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

+        # Latency Breakdown
+        self.queue_time_start = None
+        self.queue_time_end = None
+
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
@@ -483,9 +522,9 @@ class Req:
         self.lora_path = lora_path

         # For disaggregation
-        self.bootstrap_host: str = "0.0.0.0"
-        self.bootstrap_room: Optional[int] = None
-        self.disagg_kv_sender: Optional[KVSender] = None
+        self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_room: Optional[int] = bootstrap_room
+        self.disagg_kv_sender: Optional[BaseKVSender] = None

         # used for warmup because we don't have a pair yet when init
         self.skip_kv_transfer: bool = False
@@ -1440,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
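The last hunk reflects a server-args consolidation in this release: FlashMLA is now selected through the `attention_backend` value rather than a separate `enable_flashmla` flag, so backend checks compare strings. A minimal, purely illustrative sketch of the resulting dispatch (the dictionary contents and the variable name are assumptions, not code from the wheel):

```python
# Illustrative only: mirrors the string-based backend check in the hunk above.
global_server_args_dict = {"use_mla_backend": True, "attention_backend": "flashmla"}

needs_seq_lens_cpu = (
    (
        global_server_args_dict["use_mla_backend"]
        and global_server_args_dict["attention_backend"] == "flashinfer"
    )
    or global_server_args_dict["attention_backend"] == "flashmla"
    or global_server_args_dict["attention_backend"] == "fa3"
)
print(needs_seq_lens_cpu)  # True
```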