sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@ else:
 import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
@@ -117,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
+        # Pack weight for get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -248,19 +262,64 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

sglang/srt/layers/moe/topk.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.expert_location_dispatch import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@ from sglang.srt.utils import (
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
@@ -46,6 +48,11 @@ if _is_cuda:
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
 def fused_topk_torch_native(
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
            torch.compile(

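For readers unfamiliar with the routing scheme the aiter kernel above fuses, here is a rough, self-contained PyTorch sketch of biased grouped top-k selection (DeepSeek-style: a bias is added only for expert selection, while the returned weights come from the unbiased scores). The function name, the sigmoid scoring, and the argument layout are illustrative assumptions, not the exact semantics of aiter's biased_grouped_topk:

    import torch

    def biased_grouped_topk_sketch(
        gating_output, correction_bias, topk, num_expert_group, topk_group,
        renormalize, routed_scaling_factor=1.0
    ):
        num_tokens, num_experts = gating_output.shape
        scores = gating_output.sigmoid()           # (tokens, experts)
        biased = scores + correction_bias          # bias influences selection only
        # Score each expert group by the sum of its two best biased scores.
        group_scores = (
            biased.view(num_tokens, num_expert_group, -1)
            .topk(2, dim=-1).values.sum(dim=-1)
        )
        # Keep only experts that belong to the best `topk_group` groups.
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        expert_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
            .reshape(num_tokens, num_experts)
        )
        masked = biased.masked_fill(expert_mask == 0, float("-inf"))
        topk_ids = masked.topk(topk, dim=-1).indices
        topk_weights = scores.gather(1, topk_ids)  # weights from unbiased scores
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights * routed_scaling_factor, topk_ids
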
sglang/srt/layers/quantization/fp8_utils.py
@@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
-    from aiter import gemm_a8w8_blockscale_CK
+    import aiter
+    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
-    )
+    q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
     output = gemm_a8w8_blockscale_CK(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )

sglang/srt/layers/rotary_embedding.py
@@ -8,16 +8,29 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope as aiter_get_rope
+
+if is_npu():
+    import torch_npu
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -152,6 +165,36 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -847,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
@@ -1191,15 +1271,26 @@ def get_rope(
         )
     elif scaling_type == "dynamic":
         scaling_factor = rope_scaling["factor"]
-        rotary_emb = DynamicNTKScalingRotaryEmbedding(
-            head_size,
-            rotary_dim,
-            max_position,
-            base,
-            is_neox_style,
-            scaling_factor,
-            dtype,
-        )
+        if "alpha" in rope_scaling:
+            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                rope_scaling["alpha"],
+                dtype,
+            )
+        else:
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                scaling_factor,
+                dtype,
+            )
     elif scaling_type == "yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling["original_max_position_embeddings"]
@@ -1388,7 +1479,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
-        return get_rope(
+        wrapper = aiter_get_rope if _use_aiter else get_rope
+        return wrapper(
             head_size,
             rotary_dim,
             max_position,

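The DynamicNTKAlphaRotaryEmbedding added above scales the RoPE base by alpha ** (rotary_dim / (rotary_dim - 2)) instead of rescaling positions by a factor; get_rope selects it whenever rope_scaling contains an "alpha" key (the new hunyuan.py model in this release appears to be the consumer) and otherwise falls back to the factor-based DynamicNTKScalingRotaryEmbedding. A minimal standalone sketch of the base computation, with purely illustrative numbers (rotary_dim=128, base=10000, alpha=4.0 are assumptions, not values from the diff):

    # Illustrative values only.
    rotary_dim = 128
    base = 10000.0
    scaling_alpha = 4.0

    # Enlarging the base stretches the low-frequency RoPE wavelengths so that
    # positions beyond the original training range still map onto familiar phases.
    scaled_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))

    # Same inverse-frequency formula that RotaryEmbedding._compute_inv_freq applies.
    inv_freq = [scaled_base ** (-2 * i / rotary_dim) for i in range(rotary_dim // 2)]

    print(round(scaled_base))  # ~40890 for these example values
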
sglang/srt/layers/vocab_parallel_embedding.py
@@ -20,10 +20,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)

sglang/srt/managers/expert_distribution.py
@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
     def with_debug_name(self, debug_name):
         yield
 
+    @contextmanager
+    def disable_this_region(self):
+        yield
+
     @contextmanager
     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         yield
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         self._expert_location_metadata = expert_location_metadata
 
         self._recording = False
+        self._disable_all = False
         self._current_forward_pass_id = Withable()
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         finally:
             self._on_forward_pass_end(forward_pass_id)
 
+    @contextmanager
+    def disable_this_region(self):
+        """Context manager to temporarily disable recording."""
+        previous_disable_all = self._disable_all
+        self._disable_all = True
+        try:
+            yield
+        finally:
+            self._disable_all = previous_disable_all
+
     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
         if not self._recording:
             return
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         )
 
     def _on_hook(self, hook_name: str, **kwargs):
+        if self._disable_all:
+            return
         if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
+        assert self._data[layer_idx, :].shape == topk_ids.shape, (
+            "Shape mismatch between data and topk_ids."
+            "Selecting expert is not supported for multiple token prediction at the moment."
+        )
        self._data[layer_idx, :].scatter_add_(
            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
        )

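The new disable_this_region context manager suppresses expert-distribution recording around a specific block of work. A hypothetical usage sketch (the recorder instance and the function called inside the block are placeholder names, not identifiers from the diff):

    # `recorder` stands for whichever ExpertDistributionRecorder instance the server holds.
    with recorder.disable_this_region():
        run_some_auxiliary_forward()  # _on_hook() returns immediately while _disable_all is True
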
sglang/srt/managers/io_struct.py
@@ -319,8 +319,16 @@ class GenerateReqInput:
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif not isinstance(self.rid, list):
-            raise ValueError("The rid should be a list for batch processing.")
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")
 
     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""

sglang/srt/managers/multimodal_processors/base_processor.py
@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
     RAW_IMAGES = "raw_images"
     PRECOMPUTED_FEATURES = "precomputed_features"
     PIXEL_VALUES = "pixel_values"
+    AUDIO = "audio"
 
 
 @dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
                 has_image = False
                 has_pixel_values = False
                 has_precomputed_features = False
+                has_audio = False
 
                 for mm_input in mm_inputs:
                     if isinstance(mm_input, Image.Image):
                         has_image = True
+                    elif isinstance(mm_input, np.ndarray):
+                        has_audio = True
                     elif isinstance(mm_input, dict):
                         if mm_input.get("precomputed_features", None) is not None:
                             has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):
 
                 # Validate format consistency
                 format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features]
+                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
                 )
                 if format_count > 1:
                     raise ValueError(
                         "Unsupported: mixture of multimodal input formats. "
                         f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}"
+                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                     )
 
                 if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
                     return MultimodalInputFormat.PRECOMPUTED_FEATURES
                 elif has_pixel_values:
                     return MultimodalInputFormat.PIXEL_VALUES
+                elif has_audio:
+                    return MultimodalInputFormat.AUDIO
                 else:
                     raise ValueError("No valid multimodal input format found")
             except Exception as e:
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
             input_ids = tokenize_text(base_output.input_text)
             return combined_mm_item, input_ids
 
+        def process_audio(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with audio."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                audio=base_output.audios,  # Note: "audio" is for gemma3n only
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
         def finalize_mm_item(
             combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
         ) -> MultimodalDataItem:
             """Apply common post-processing to the multimodal item."""
-            combined_mm_item.image_offsets = self.get_mm_items_offset(
-                input_ids=input_ids,
-                mm_token_id=self.IM_TOKEN_ID,
-            )
+            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                combined_mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.AUDIO:
+                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.VIDEO:
+                combined_mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
             return combined_mm_item
 
-        # Main logic
-        mm_inputs = base_output.images
+        # Main logic - determine input type and handle text-only case
+        mm_inputs = base_output.images or base_output.audios
         if not mm_inputs:
-            # Return text-only case
             input_ids = tokenize_text(base_output.input_text)
             return None, input_ids
 
@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
             combined_mm_item, input_ids = process_precomputed_features(base_output)
         elif input_format == MultimodalInputFormat.PIXEL_VALUES:
             combined_mm_item, input_ids = process_pixel_values(base_output)
+        elif input_format == MultimodalInputFormat.AUDIO:
+            combined_mm_item, input_ids = process_audio(base_output)
         else:
             raise ValueError(f"Unknown input format: {input_format}")
 
sglang/srt/managers/multimodal_processors/gemma3n.py
@@ -0,0 +1,97 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import re
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
+
+
+class Gemma3nSGLangProcessor(SGLangBaseProcessor):
+    """Multimodal processor for Gemma3n supporting image and audio inputs."""
+
+    models = [Gemma3nForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+        self.IMAGE_TOKEN = "<image_soft_token>"
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+        )
+
+        self.AUDIO_TOKEN = "<audio_soft_token>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
+        )
+
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.IM_START_TOKEN_ID = hf_config.boi_token_id
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_id
+
+        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
+        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
+        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        input_text: str = "",
+        request_obj=None,
+        max_req_input_len: int = 0,
+        *args,
+        **kwargs,
+    ):
+        """Process multimodal data including images and audio."""
+
+        audio_data = request_obj.audio_data
+        if not image_data and not audio_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        if isinstance(audio_data, str):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            audio_data=audio_data,
+            max_req_input_len=max_req_input_len,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN,
+                image_token_regex=self.IMAGE_TOKEN_REGEX,
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "audio_start_id": self.AUDIO_START_TOKEN_ID,
+            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+        }