sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
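The list above only shows per-file line deltas; to confirm which build is actually installed after upgrading, the distribution metadata can be queried with the standard library (the expected output assumes the post1 wheel shown here is the one installed):

```python
from importlib.metadata import version

# Prints "0.4.8.post1" once this wheel is installed.
print(version("sglang"))
```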
sglang/srt/layers/moe/fused_moe_triton/layer.py CHANGED

@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@ else:
 import logging

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
@@ -117,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight for get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return

     def apply(
@@ -248,19 +262,64 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
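To make the intent of the new CPU branch easier to see, here is a minimal sketch (not SGLang's actual control flow) of when the fused CPU expert kernel is taken: the build must be a CPU build with AMX support (so the weights were packed after loading), the layer must carry the `use_intel_amx_backend` flag, and router weights must not be applied on the input. The predicate below is hypothetical; only `is_cpu` and `cpu_has_amx_support` come from the diff.

```python
from sglang.srt.utils import cpu_has_amx_support, is_cpu


def would_use_cpu_fused_experts(layer, apply_router_weight_on_input: bool) -> bool:
    """Hypothetical predicate mirroring the gating added above (illustrative only)."""
    if not (is_cpu() and cpu_has_amx_support()):
        # Weights were never packed by _process_weight_after_loading,
        # so the fused_experts_cpu kernel path is unavailable.
        return False
    return (
        getattr(layer, "use_intel_amx_backend", False)
        and not apply_router_weight_on_input
    )
```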
sglang/srt/layers/moe/topk.py CHANGED

@@ -30,6 +30,7 @@ from sglang.srt.managers.expert_location_dispatch import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@ from sglang.srt.utils import (

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

@@ -46,6 +48,11 @@ if _is_cuda:

 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")


 def fused_topk_torch_native(
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
sglang/srt/layers/quantization/fp8_utils.py CHANGED

@@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
+    import aiter
+    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]

-    q_input, x_scale =
-        input_2d, block_size[1], column_major_scales=False
-    )
+    q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
     output = gemm_a8w8_blockscale_CK(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
sglang/srt/layers/rotary_embedding.py CHANGED

@@ -8,16 +8,29 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope as aiter_get_rope
+
+if is_npu():
+    import torch_npu


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -152,6 +165,36 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -847,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)


+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""

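A short worked example of the base adjustment performed by `DynamicNTKAlphaRotaryEmbedding._compute_cos_sin_cache` above: the RoPE base is scaled by `alpha ** (rotary_dim / (rotary_dim - 2))` before the inverse frequencies are computed. The numbers below are illustrative, not taken from any real model config.

```python
# Illustrative values only.
base = 10000.0
rotary_dim = 128
scaling_alpha = 8.0

# Same formula as in _compute_cos_sin_cache above.
effective_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))
print(effective_base)  # roughly 8.27e4: a larger base stretches the usable context window
```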
@@ -1191,15 +1271,26 @@ def get_rope(
             )
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
+            if "alpha" in rope_scaling:
+                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    rope_scaling["alpha"],
+                    dtype,
+                )
+            else:
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    scaling_factor,
+                    dtype,
+                )
         elif scaling_type == "yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling["original_max_position_embeddings"]
@@ -1388,7 +1479,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
+        wrapper = aiter_get_rope if _use_aiter else get_rope
+        return wrapper(
             head_size,
             rotary_dim,
             max_position,
sglang/srt/layers/vocab_parallel_embedding.py CHANGED

@@ -20,10 +20,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 DEFAULT_VOCAB_PADDING_SIZE = 64

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
sglang/srt/managers/expert_distribution.py CHANGED

@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
     def with_debug_name(self, debug_name):
         yield

+    @contextmanager
+    def disable_this_region(self):
+        yield
+
     @contextmanager
     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         yield
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         self._expert_location_metadata = expert_location_metadata

         self._recording = False
+        self._disable_all = False
         self._current_forward_pass_id = Withable()
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         finally:
             self._on_forward_pass_end(forward_pass_id)

+    @contextmanager
+    def disable_this_region(self):
+        """Context manager to temporarily disable recording."""
+        previous_disable_all = self._disable_all
+        self._disable_all = True
+        try:
+            yield
+        finally:
+            self._disable_all = previous_disable_all
+
     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
         if not self._recording:
             return
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         )

     def _on_hook(self, hook_name: str, **kwargs):
+        if self._disable_all:
+            return
         if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
+        assert self._data[layer_idx, :].shape == topk_ids.shape, (
+            "Shape mismatch between data and topk_ids."
+            "Selecting expert is not supported for multiple token prediction at the moment."
+        )
         self._data[layer_idx, :].scatter_add_(
             dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
         )
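The `disable_this_region()` context manager added above is a save/set/restore around a `_disable_all` flag that `_on_hook` checks first. A self-contained stand-in (not the SGLang class itself) showing the same pattern and why it nests and restores safely:

```python
from contextlib import contextmanager


class _RecorderSketch:
    """Stand-in for _ExpertDistributionRecorderReal; only the new flag logic is shown."""

    def __init__(self):
        self._disable_all = False
        self.events = []

    @contextmanager
    def disable_this_region(self):
        # Save/set/restore so nested regions re-enable recording correctly on exit.
        previous = self._disable_all
        self._disable_all = True
        try:
            yield
        finally:
            self._disable_all = previous

    def on_hook(self, name):
        if self._disable_all:
            return
        self.events.append(name)


rec = _RecorderSketch()
rec.on_hook("layer_0")
with rec.disable_this_region():
    rec.on_hook("draft_pass")  # dropped while the region is disabled
rec.on_hook("layer_1")
print(rec.events)  # ['layer_0', 'layer_1']
```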
sglang/srt/managers/io_struct.py CHANGED

@@ -319,8 +319,16 @@ class GenerateReqInput:
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")

     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""
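The request-id normalization above fans a single string `rid` out to one id per batch element and validates an explicit list against the batch size. A self-contained sketch of the same rule (not the `GenerateReqInput` method itself):

```python
import uuid


def normalize_rids(rid, num):
    """Mirror of the rid handling added above, for illustration only."""
    if rid is None:
        return [uuid.uuid4().hex for _ in range(num)]
    if isinstance(rid, str):
        return [f"{rid}_{i}" for i in range(num)]
    if isinstance(rid, list):
        if len(rid) != num:
            raise ValueError("The specified rids length mismatch with the batch_size.")
        return rid
    raise ValueError("The rid should be a string or a list of strings.")


print(normalize_rids("req-42", 3))  # ['req-42_0', 'req-42_1', 'req-42_2']
```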
sglang/srt/managers/multimodal_processors/base_processor.py CHANGED

@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
     RAW_IMAGES = "raw_images"
     PRECOMPUTED_FEATURES = "precomputed_features"
     PIXEL_VALUES = "pixel_values"
+    AUDIO = "audio"


 @dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
                 has_image = False
                 has_pixel_values = False
                 has_precomputed_features = False
+                has_audio = False

                 for mm_input in mm_inputs:
                     if isinstance(mm_input, Image.Image):
                         has_image = True
+                    elif isinstance(mm_input, np.ndarray):
+                        has_audio = True
                     elif isinstance(mm_input, dict):
                         if mm_input.get("precomputed_features", None) is not None:
                             has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):

                 # Validate format consistency
                 format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features]
+                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
                 )
                 if format_count > 1:
                     raise ValueError(
                         "Unsupported: mixture of multimodal input formats. "
                         f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}"
+                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                     )

                 if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
                     return MultimodalInputFormat.PRECOMPUTED_FEATURES
                 elif has_pixel_values:
                     return MultimodalInputFormat.PIXEL_VALUES
+                elif has_audio:
+                    return MultimodalInputFormat.AUDIO
                 else:
                     raise ValueError("No valid multimodal input format found")
             except Exception as e:
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
             input_ids = tokenize_text(base_output.input_text)
             return combined_mm_item, input_ids

+        def process_audio(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with audio."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                audio=base_output.audios,  # Note: "audio" is for gemma3n only
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
         def finalize_mm_item(
             combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
         ) -> MultimodalDataItem:
             """Apply common post-processing to the multimodal item."""
-            combined_mm_item.
+            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                combined_mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.AUDIO:
+                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.VIDEO:
+                combined_mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
             return combined_mm_item

-        # Main logic
-        mm_inputs = base_output.images
+        # Main logic - determine input type and handle text-only case
+        mm_inputs = base_output.images or base_output.audios
         if not mm_inputs:
-            # Return text-only case
             input_ids = tokenize_text(base_output.input_text)
             return None, input_ids

@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
             combined_mm_item, input_ids = process_precomputed_features(base_output)
         elif input_format == MultimodalInputFormat.PIXEL_VALUES:
             combined_mm_item, input_ids = process_pixel_values(base_output)
+        elif input_format == MultimodalInputFormat.AUDIO:
+            combined_mm_item, input_ids = process_audio(base_output)
         else:
             raise ValueError(f"Unknown input format: {input_format}")

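With the changes above, a decoded audio waveform (a NumPy array) is now recognized as its own input format alongside PIL images, pixel values, and precomputed features, and mixed formats are still rejected. A small self-contained sketch of just that classification rule (not the SGLang helper itself):

```python
import numpy as np
from PIL import Image


def classify(mm_inputs):
    """Illustrative version of the format detection extended above."""
    has_image = any(isinstance(x, Image.Image) for x in mm_inputs)
    has_audio = any(isinstance(x, np.ndarray) for x in mm_inputs)
    if has_image and has_audio:
        raise ValueError("Unsupported: mixture of multimodal input formats.")
    if has_audio:
        return "audio"
    return "raw_images" if has_image else "text_only"


print(classify([np.zeros(16000, dtype=np.float32)]))  # 'audio'
```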
sglang/srt/managers/multimodal_processors/gemma3n.py ADDED

@@ -0,0 +1,97 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import re
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
+
+
+class Gemma3nSGLangProcessor(SGLangBaseProcessor):
+    """Multimodal processor for Gemma3n supporting image and audio inputs."""
+
+    models = [Gemma3nForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+        self.IMAGE_TOKEN = "<image_soft_token>"
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+        )
+
+        self.AUDIO_TOKEN = "<audio_soft_token>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
+        )
+
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.IM_START_TOKEN_ID = hf_config.boi_token_id
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_id
+
+        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
+        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
+        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        input_text: str = "",
+        request_obj=None,
+        max_req_input_len: int = 0,
+        *args,
+        **kwargs,
+    ):
+        """Process multimodal data including images and audio."""
+
+        audio_data = request_obj.audio_data
+        if not image_data and not audio_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        if isinstance(audio_data, str):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            audio_data=audio_data,
+            max_req_input_len=max_req_input_len,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN,
+                image_token_regex=self.IMAGE_TOKEN_REGEX,
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "audio_start_id": self.AUDIO_START_TOKEN_ID,
+            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+        }