sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/data_parallel_controller.py

@@ -23,13 +23,16 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback
 
@@ -174,6 +177,10 @@ class DataParallelController:
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +215,8 @@ class DataParallelController:
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
@@ -220,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
-    def round_robin_scheduler(self, req):
-        self.workers[self.round_robin_counter].send_pyobj(req)
-        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
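
The reworked round_robin_scheduler keeps plain round-robin dispatch when disaggregation is off (disaggregation_mode == "null") and otherwise pins a request to the worker indexed by bootstrap_room % len(workers), so every message carrying the same bootstrap room reaches the same scheduler. A minimal sketch of that decision with stand-in worker names instead of the real ZMQ sockets (hypothetical helper, not sglang code):

    # Sketch only: how a request is routed under the two modes.
    workers = ["worker-0", "worker-1", "worker-2", "worker-3"]

    def pick_worker(disaggregation_mode: str, bootstrap_room: int, rr_counter: int):
        if disaggregation_mode == "null":
            # normal serving: rotate through the workers
            return workers[rr_counter % len(workers)], rr_counter + 1
        # disaggregated serving: the same bootstrap_room always maps to the same worker
        return workers[bootstrap_room % len(workers)], rr_counter

    worker, rr_counter = pick_worker("decode", bootstrap_room=12345, rr_counter=0)
    # 12345 % 4 == 1, so this request goes to "worker-1" regardless of the counter
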

sglang/srt/managers/io_struct.py

@@ -96,8 +96,9 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[int], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +398,15 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_port=(
+                self.bootstrap_port[i] if self.bootstrap_port is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
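
Because bootstrap_host, bootstrap_port, and bootstrap_room may now be per-request lists, splitting a batched GenerateReqInput hands element i of each list to the i-th single request, mirroring the slicing above. A rough sketch with a stripped-down stand-in dataclass (not the real GenerateReqInput):

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class BatchReq:  # stand-in for the per-field slicing in GenerateReqInput
        text: List[str]
        bootstrap_room: Optional[Union[List[int], int]] = None

        def split(self):
            return [
                {
                    "text": self.text[i],
                    "bootstrap_room": (
                        self.bootstrap_room[i] if self.bootstrap_room is not None else None
                    ),
                }
                for i in range(len(self.text))
            ]

    print(BatchReq(text=["a", "b"], bootstrap_room=[101, 202]).split())
    # [{'text': 'a', 'bootstrap_room': 101}, {'text': 'b', 'bootstrap_room': 202}]
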
@@ -441,6 +451,7 @@ class TokenizedGenerateReqInput:
 
     # For disaggregated inference
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
 
@@ -457,6 +468,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
@@ -665,10 +678,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights

sglang/srt/managers/mm_utils.py

@@ -10,12 +10,13 @@ import torch
 from torch import nn
 
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once
 
 logger = logging.getLogger(__name__)
 
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids
 
 
-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self, image_token_id: torch.Tensor) -> None:
-        self.image_token_id = image_token_id
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids
 
-    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-        This function will replace the data-tokens in between with pad_values accordingly
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        assert len(pad_values) != 0
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-        mask = torch.isin(input_ids_tensor, self.image_token_id)
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-        num_image_tokens = mask.sum().item()
-        repeated_pad_values = torch.tensor(pad_values).repeat(
-            num_image_tokens // len(pad_values) + 1
-        )[:num_image_tokens]
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:  # If no regions or no pad_values, this loop won't run anyway.
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
 
-        input_ids_tensor[mask] = repeated_pad_values
-        return input_ids_tensor.tolist()
+        return output_ids_tensor.tolist()
 
 
 def get_embedding_and_mask(
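
A self-contained illustration of the region-finding trick used in the new pad_input_tokens (pure torch, toy token ids, not sglang code): padding the boolean mask with False on both ends turns every run of multimodal tokens into one False-to-True and one True-to-False transition, so the even/odd change indices give each region's start and end.

    import torch

    input_ids = [5, 5, 9, 7, 7, 7, 9, 5]      # pretend 5 = <image>, 7 = <audio>, 9 = text
    mm_token_ids = torch.tensor([5, 7])
    pad_values = [1001, 1002, 1003]            # one pad value per multimodal item

    ids = torch.tensor(input_ids)
    mask = torch.isin(ids, mm_token_ids)
    padded = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
    changes = torch.where(padded[1:] != padded[:-1])[0]
    starts, ends = changes[::2], changes[1::2]  # starts = [0, 3, 7], ends = [2, 6, 8]

    out = ids.clone()
    for s, e, v in zip(starts, ends, pad_values):
        out[s:e] = v                            # overwrite each region with its pad value
    print(out.tolist())  # [1001, 1001, 9, 1002, 1002, 1002, 9, 1003]
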
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
 
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
-
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
 
     Args:
-        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used
 
     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
 
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-    placeholder_token_ids = placeholder_token_ids or [
-        item.pad_value for item in mm_inputs.mm_items
-    ]
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)
 
     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
 
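
Here placeholder_tokens maps a modality to the token ids that mark its slots, and flattening the dict values yields the id list matched against input_ids. A hedged sketch of that call site (flatten_nested_list is sglang's own helper; the one-level flatten and the token ids below are assumptions for illustration only):

    def flatten_one_level(nested):
        # assumed behavior of flatten_nested_list for this call site
        return [x for sub in nested for x in sub]

    placeholder_tokens = {
        "IMAGE": [151655],   # stand-ins for Modality.IMAGE / Modality.AUDIO entries
        "AUDIO": [151646],
    }
    placeholder_token_ids = flatten_one_level(list(placeholder_tokens.values()))
    print(placeholder_token_ids)  # [151655, 151646]
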
@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.warning_once(
+        logger.warning(
             "No multimodal data item's pad value exist in placeholder ids. Using all items"
         )
         using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
             data_embedding_func=image_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
             data_embedding_func=audio_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                placeholder_tokens[Modality.AUDIO]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
     input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)
 
-    # 4. scatter embeddings into input embedding
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding
 
     Returns:
-        inputs_embedding
         forwarded hidden states
 
     """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
        # just being defensive here
        forward_batch.mm_inputs = None
    else:

sglang/srt/managers/multimodal_processors/base_processor.py

@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
 
     @abstractmethod
    async def process_mm_data_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass
 
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu
 
         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
 
         """
+
+        if image_data is None:
+            image_data = []
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
             prompt = self._processor.tokenizer.decode(prompt)
         else:
             prompt = prompt
+
+        assert isinstance(prompt, str)
         if return_text:
             import re
 

sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py

@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union
 
 import torch
 
@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"
 
     async def process_mm_data_async(
-        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
 
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids,
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,

sglang/srt/managers/multimodal_processors/gemma3.py

@@ -1,7 +1,5 @@
 from typing import List, Union
 
-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)
 
 
 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
 
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,

sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor
 
         base_out = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token

sglang/srt/managers/multimodal_processors/minicpm.py

@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]
 
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -96,7 +97,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id
 
-        im_token_id = tokenizer.unk_token_id
+        im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
 

sglang/srt/managers/multimodal_processors/qwen_vl.py

@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image
 
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        prompt,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]
 
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=prompt,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)
 
-        resize_tasks = [resize_image_async(image) for image in base_output.images]
-        resized_images = await asyncio.gather(*resize_tasks)
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)
 
         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=resized_images,
+            images=base_output.images,
         )
 
-        image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
-            "mm_items": [
+        items = []
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
-                    image_grid_thws=image_grid_thws,
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
                     # TODO
                     video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
                     modality=Modality.IMAGE,
                 )
-            ],
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
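
The processor now precomputes the M-RoPE position ids via MRotaryEmbedding.get_rope_index and returns them alongside the token ids; the positions carry three components per token (temporal, height, width), and for plain text tokens all three are assumed to equal the ordinary position index. A toy sketch of that shape convention only (not sglang's implementation):

    import torch

    seq_len = 6
    text_positions = torch.arange(seq_len)
    # for a text-only prompt, the three M-RoPE components are identical
    mrope_positions = text_positions.unsqueeze(0).expand(3, -1)
    print(mrope_positions.shape)  # torch.Size([3, 6])
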