sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/hunyuan.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module):
             else config.moe_intermediate_size[layer_id]
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True if top_k > 1 else False,
+        )
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
-            top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
-            renormalize=True if top_k > 1 else False,
             quant_config=quant_config,
         )

@@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module):

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
@@ -206,6 +209,42 @@ class HunYuanSparseMoeBlock(nn.Module):
         return final_hidden_states.view(orig_shape)


+def get_head_dim(config):
+    if hasattr(config, "head_dim"):
+        return int(config.head_dim)
+    if hasattr(config, "attention_head_dim"):
+        return int(config.attention_head_dim)
+
+    # since some hunyuan model don't follow the self.hidden_size // self.total_num_heads rule
+    # wrong setting may cause runtime error, just throw error if this field is missing.
+    raise ValueError("Missing head dim config, try set head_dim in config.json")
+
+
+def check_head_dim(config):
+    # Some models may lack `head_dim` and use `attention_head_dim` instead.
+    # This attribute is also used by flashinfer_backend.py, so we check for
+    # consistency and raise an error if it's not met to avoid silent failures.
+    # Although we could adapt the HunYuan model to use `attention_head_dim`,
+    # flashinfer expects `head_dim`, so we enforce its presence for correctness.
+    calc_head_dim = config.hidden_size // config.num_attention_heads
+
+    if hasattr(config, "attention_head_dim"):
+        if calc_head_dim != config.attention_head_dim and not hasattr(
+            config, "head_dim"
+        ):
+            # in this case, flash infer(and other components may calculate wrong value.)
+            raise ValueError(
+                f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}"
+                + f"\nPlease Add head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
+            )
+
+        if hasattr(config, "head_dim") and config.attention_head_dim != config.head_dim:
+            raise ValueError(
+                f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})"
+                + f"\nPlease change head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
+            )
+
+
 class HunYuanAttention(nn.Module):

     def __init__(
@@ -240,9 +279,11 @@ class HunYuanAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(
-            config, "head_dim", self.hidden_size // self.total_num_heads
-        )
+        # Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -493,7 +534,6 @@ class HunYuanModel(nn.Module):
         hidden_states = self.get_input_embeddings(input_ids)
         residual = None

-        cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -560,6 +600,11 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight

+        self.hidden_size = config.hidden_size
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
         self.sampler = Sampler()
@@ -582,16 +627,14 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
                self.config, "num_key_value_heads", self.config.num_attention_heads
            )
            num_key_value_groups = num_attention_heads // num_kv_heads
-           hidden_size = self.config.hidden_size
-           attention_head_dim = self.config.hidden_size // num_attention_heads

            qkv = qkv.reshape(
-               num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
+               num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
            )
            q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
-           q = q.reshape(-1, hidden_size)
-           k = k.reshape(-1, hidden_size)
-           v = v.reshape(-1, hidden_size)
+           q = q.reshape(-1, self.hidden_size)
+           k = k.reshape(-1, self.hidden_size)
+           v = v.reshape(-1, self.hidden_size)
            return torch.concat((q, k, v))
            # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),

@@ -768,4 +811,8 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         )


-EntryClass = HunYuanMoEV1ForCausalLM
+class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
+    pass
+
+
+EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]
sglang/srt/models/internvl.py
CHANGED
@@ -510,7 +510,7 @@ class InternVLChatModel(nn.Module):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.cat([item.pixel_values for item in items])
+        pixel_values = torch.cat([item.feature for item in items])
         image_features = self.extract_feature(pixel_values)
         return image_features

sglang/srt/models/kimi_vl.py
CHANGED
@@ -144,7 +144,7 @@ class KimiVLForConditionalGeneration(nn.Module):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = (
-            torch.cat([item.pixel_values for item in items], dim=0)
+            torch.cat([item.feature for item in items], dim=0)
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
sglang/srt/models/llama.py
CHANGED
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ) -> Optional[LogitsProcessorOutput]:
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
sglang/srt/models/llama4.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -103,14 +104,17 @@ class Llama4MoE(nn.Module):
             prefix=add_prefix("router", prefix),
         )

+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=False,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+        )
+
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            custom_routing_function=Llama4MoE.custom_routing_function,
             intermediate_size=intermediate_size_moe,
             reduce_results=False,
-            renormalize=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
             prefix=add_prefix("experts", prefix),
@@ -147,10 +151,8 @@ class Llama4MoE(nn.Module):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
-        routed_out = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        routed_out = self.experts(hidden_states, topk_output)
         return shared_out, routed_out

     def _forward_core_shared_routed_overlap(self, hidden_states):
@@ -163,10 +165,8 @@ class Llama4MoE(nn.Module):
         with self.device_module.stream(alt_stream):
             # router_scores: [num_tokens, num_experts]
             router_logits, _ = self.router(hidden_states)
-            routed_out = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-            )
+            topk_output = self.topk(hidden_states, router_logits)
+            routed_out = self.experts(hidden_states, topk_output)
         self.device_module.current_stream().wait_stream(alt_stream)

         return shared_out, routed_out
sglang/srt/models/llava.py
CHANGED
@@ -186,7 +186,7 @@ class LlavaBaseForCausalLM(nn.Module):
         bs = forward_batch.batch_size
         pixel_values = flatten_nested_list(
             [
-                [item.pixel_values for item in image_inputs[i].mm_items]
+                [item.feature for item in image_inputs[i].mm_items]
                 for i in range(bs)
                 if need_vision[i]
             ]
@@ -753,7 +753,7 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
sglang/srt/models/llavavid.py
CHANGED
@@ -135,7 +135,7 @@ class LlavaVidForCausalLM(nn.Module):
         if need_vision.any():
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
sglang/srt/models/minicpm.py
CHANGED
@@ -138,8 +138,6 @@ class MiniCPMAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        # set rope as fp32 instead of bf16
-        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
sglang/srt/models/minicpmo.py
CHANGED
@@ -1552,9 +1552,7 @@ class MiniCPMO(MiniCPMBaseModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1659,9 +1657,7 @@ class MiniCPMO(MiniCPMBaseModel):
             List[List[torch.Tensor]]: audio embeddings
         """
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1778,7 +1774,7 @@ class MiniCPMO(MiniCPMBaseModel):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
sglang/srt/models/minicpmv.py
CHANGED
@@ -724,7 +724,7 @@ class MiniCPMV2_6(MiniCPMBaseModel):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
sglang/srt/models/mistral.py
CHANGED
@@ -56,7 +56,7 @@ class Mistral3ForConditionalGeneration:
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
sglang/srt/models/mixtral.py
CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -86,6 +87,12 @@ class MixtralMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("gate", prefix),
         )
+
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=num_experts,
@@ -93,7 +100,6 @@ class MixtralMoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
@@ -105,7 +111,8 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
sglang/srt/models/mllama.py
CHANGED
@@ -838,9 +838,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pixel_values = torch.cat(
-            [item.pixel_values for item in mm_inputs.mm_items], dim=0
-        )
+        pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0)
         pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
@@ -862,7 +860,7 @@ class MllamaForConditionalGeneration(nn.Module):

             if not forward_batch.encoder_cached[i] and mm_input is not None:
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 max_num_images = max(max_num_images, pixel_values.shape[1])

@@ -897,7 +895,7 @@ class MllamaForConditionalGeneration(nn.Module):

             encoder_lens_need.append(forward_batch.encoder_lens[k])
             pixel_values = torch.cat(
-                [item.pixel_values for item in mm_input.mm_items], dim=0
+                [item.feature for item in mm_input.mm_items], dim=0
             )
             for j in range(pixel_values.shape[1]):
                 img = pixel_values[0, j]
sglang/srt/models/mllama4.py
CHANGED
@@ -81,6 +81,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(
             config.text_config if hasattr(config, "text_config") else config
         )
+        self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()

     def _has_vision_weights(self, config) -> bool:
         """Check if the model has vision components by examining the checkpoint."""
@@ -135,8 +136,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         return False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
         self,
@@ -147,7 +147,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             raise ValueError("Vision model not available for text-only checkpoint")

         pixel_values = (
-            torch.concat([item.pixel_values for item in items])
+            torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )
sglang/srt/models/olmoe.py
CHANGED
@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+        )
+
         self.experts = FusedMoE(
             num_experts=num_experts,
-            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
@@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
