sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/phimoe.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
             quant_config=None,
         )
 
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=phimoe_routing_function,
+        )
+
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
-            custom_routing_function=phimoe_routing_function,
             prefix=add_prefix("experts", prefix),
         )
 
@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
 
 
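Note: together with the qwen2_moe.py and qwen3_moe.py diffs below, this is one refactor applied across the MoE models in this release: top-k routing moves out of the FusedMoE/EPMoE constructor arguments (renormalize, custom_routing_function, use_grouped_topk) into a standalone TopK module, and the forward pass becomes an explicit gate -> topk -> experts pipeline. A minimal, self-contained sketch of the new call shape follows; NaiveTopK and moe_forward are illustrative stand-ins, not sglang's TopK or FusedMoE.

import torch
from torch import nn


class NaiveTopK(nn.Module):
    """Stand-in for sglang's TopK: routing decisions only, no expert compute."""

    def __init__(self, top_k: int, renormalize: bool = False):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # (num_tokens, n_experts) logits -> per-token expert weights and ids.
        weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = weights.topk(self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids


def moe_forward(gate, topk, experts, hidden_states):
    # The two-step pattern the diffs converge on: the experts layer consumes
    # a precomputed topk_output instead of raw router logits.
    router_logits = gate(hidden_states)
    topk_output = topk(hidden_states, router_logits)
    return experts(hidden_states, topk_output)

Per-model routing policy (PhiMoE's custom_routing_function, Qwen's norm_topk_prob) now travels with the TopK module, so the expert kernels no longer carry routing arguments.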
sglang/srt/models/qwen.py
CHANGED
@@ -15,6 +15,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
 
+import time
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            forward_batch.hidden_states = self.transformer.wte(input_ids)
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.transformer.h[i]
+            forward_batch.hidden_states = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+            )
+
+        if end == self.transformer.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.transformer.ln_f(
+                forward_batch.hidden_states
+            )
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
sglang/srt/models/qwen2.py
CHANGED
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
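Note: the forward_split_prefill methods added in this release (qwen.py above; qwen2_moe.py, qwen3.py, and qwen3_moe.py below) share one contract: run decoder layers [start, end), keep intermediate activations on forward_batch between calls, and return logits only from the chunk that reaches the final layer. A hypothetical driver loop under that contract follows; the chunk size and the loop itself are not part of this diff.

def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    # Assumes the forward_split_prefill contract above, with the layer count
    # exposed as model.model.config.num_hidden_layers (qwen.py uses
    # model.transformer.config instead).
    num_layers = model.model.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        # Intermediate chunks return None; activations persist on forward_batch.
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result  # LogitsProcessorOutput from the final chunk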
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
-        input_features = torch.cat([item.
+        input_features = torch.cat([item.feature for item in items], dim=0).type(
             self.audio_tower.dtype
         )
 
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             f"the number of experts {config.num_experts}."
         )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+        )
+
         self.experts = get_moe_impl_class()(
             layer_id=self.layer_id,
-            num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
+            num_experts=config.num_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             # Additional args for FusedMoE
@@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -406,6 +410,7 @@ class Qwen2MoeModel(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
+        self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.pp_group = get_pp_group()
@@ -554,6 +559,49 @@ class Qwen2MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
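Note: the get_image_feature/get_video_feature/get_audio_feature changes in qwen2_5_vl.py, qwen2_audio.py, and qwen2_vl.py (and vila.py below) are one rename: each MultimodalDataItem now exposes its raw tensor through a single feature attribute instead of a modality-specific field (the removed attribute names are truncated in this rendering). A sketch of the shared pattern that results; the helper name is ours, not the library's.

import torch
from typing import Sequence


def concat_item_features(items: Sequence, dtype: torch.dtype) -> torch.Tensor:
    # One accessor for every modality: image and video pixel values and
    # audio input features all live on `item.feature` now.
    return torch.cat([item.feature for item in items], dim=0).type(dtype)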
sglang/srt/models/qwen3.py
CHANGED
@@ -1,5 +1,4 @@
 # Adapted from qwen2.py
-
 import logging
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -331,6 +330,30 @@
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
     @torch.no_grad()
     def forward(
         self,
@@ -367,6 +390,47 @@
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
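Note: get_hidden_dim reports each projection's (input_dim, output_dim) without reaching into the layer objects; head_dim appears in the arithmetic because Qwen3's q/k/v widths are head_dim * num_heads rather than hidden_size. The lora_manager.py and mem_pool.py entries in the file list suggest LoRA buffer sizing as the consumer, but the call site is not shown in this diff, so the sketch below is an assumption.

def lora_ab_shapes(model, module_name: str, rank: int):
    # Hypothetical consumer: size the LoRA A/B matrices for one module.
    input_dim, output_dim = model.get_hidden_dim(module_name)
    return (rank, input_dim), (output_dim, rank)  # A: rank x in, B: out x rank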
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -56,8 +56,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             f"the number of experts {config.num_experts}."
         )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=False,
+        )
+
         self.experts = get_moe_impl_class()(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             **(
@@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
         )
         self.top_k = config.num_experts_per_tok
-        self.renormalize = config.norm_topk_prob
 
         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,
@@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
@@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
-
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
@@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
            ):
-                state.topk_weights_local, state.topk_idx_local = select_experts(
+                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
-                    top_k=self.top_k,
-                    use_grouped_topk=False,
-                    renormalize=self.renormalize,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
                     expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                         layer_id=self.layer_id,
@@ -745,6 +740,49 @@ class Qwen3MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
sglang/srt/models/vila.py
CHANGED
@@ -237,7 +237,7 @@ class VILAForConditionalGeneration(nn.Module):
         return cast(LogitsProcessorOutput, output)
 
     def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
-        pixel_values = cast(Tensor, mm_input[0].
+        pixel_values = cast(Tensor, mm_input[0].feature)
 
         ##### BEGIN COPY modeling_vila.py #####
 