sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/phimoe.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
             quant_config=None,
         )
 
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=phimoe_routing_function,
+        )
+
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
-            custom_routing_function=phimoe_routing_function,
             prefix=add_prefix("experts", prefix),
         )
 
@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         router_logits, _ = self.gate(hidden_states)
-
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
 
 
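The PhiMoE change above is one instance of a refactor that recurs throughout this release (see qwen2_moe.py and qwen3_moe.py below): top-k routing moves out of FusedMoE into a standalone TopK module, and the experts now consume the routing result instead of raw router logits. A minimal sketch of the new call shape, assuming only what the diff shows; ToyTopK is an illustration, not sglang's TopK:

```python
import torch

class ToyTopK(torch.nn.Module):
    """Illustrative stand-in for sglang.srt.layers.moe.topk.TopK."""

    def __init__(self, top_k: int, renormalize: bool = False):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # Softmax over experts, then keep the top-k weights and indices.
        weights = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(weights, self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

# Old call shape: final = experts(hidden_states, router_logits)
# New call shape: final = experts(hidden_states, topk(hidden_states, router_logits))
```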
sglang/srt/models/qwen.py
CHANGED
@@ -15,6 +15,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
 
+import time
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            forward_batch.hidden_states = self.transformer.wte(input_ids)
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.transformer.h[i]
+            forward_batch.hidden_states = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+            )
+
+        if end == self.transformer.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.transformer.ln_f(
+                forward_batch.hidden_states
+            )
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
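forward_split_prefill is added to several models in this release (qwen2.py, qwen2_moe.py, qwen3.py, and qwen3_moe.py below receive the same method). It runs only the layers in [start, end), parking intermediate activations on forward_batch.hidden_states between calls, and returns logits only from the chunk that finishes the stack. A hypothetical driver loop, not sglang API, showing the intended calling pattern:

```python
def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    # Tile the layer stack into [start, end) intervals; intermediate
    # activations persist on forward_batch between calls.
    num_layers = model.transformer.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result  # logits output from the final chunk; None until then
```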
sglang/srt/models/qwen2.py
CHANGED
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
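This is part of a rename that shows up across the multimodal models in this diff (qwen2_audio.py, qwen2_vl.py, and vila.py below follow suit): feature extractors now read a single MultimodalDataItem attribute, feature, where modality-specific attributes were read before. The shared pattern, written as a hypothetical helper (not part of sglang):

```python
import torch

def cat_features(items, dtype: torch.dtype) -> torch.Tensor:
    # items: List[MultimodalDataItem]; every modality now exposes `feature`.
    return torch.cat([item.feature for item in items], dim=0).type(dtype)
```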
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
-        input_features = torch.cat([item.
+        input_features = torch.cat([item.feature for item in items], dim=0).type(
             self.audio_tower.dtype
         )
 
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -61,6 +57,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -134,13 +131,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+        )
+
         self.experts = get_moe_impl_class()(
             layer_id=self.layer_id,
-            num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
+            num_experts=config.num_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             # Additional args for FusedMoE
@@ -189,9 +190,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
+        self.config = config
        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
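Two details worth noting in qwen2_moe.py: the sparse block gets the same TopK split as PhiMoE above, and Qwen2MoeModel now stores self.config = config, which the new forward_split_prefill relies on when it compares end against self.model.config.num_hidden_layers.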
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
sglang/srt/models/qwen3.py
CHANGED
@@ -1,5 +1,4 @@
 # Adapted from qwen2.py
-
 import logging
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -331,6 +330,30 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
     @torch.no_grad()
     def forward(
         self,
@@ -367,6 +390,47 @@ class Qwen3ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
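get_hidden_dim reports the (input_dim, output_dim) of each projection from the model config. A hypothetical caller, an illustration only and not sglang code, showing how an adapter could use it to size per-module weight pairs without hard-coding shapes:

```python
import torch

def make_lora_pair(model, module_name: str, rank: int = 16):
    # Look up the projection's shape from the model itself.
    input_dim, output_dim = model.get_hidden_dim(module_name)
    lora_A = torch.zeros(rank, input_dim)
    lora_B = torch.zeros(output_dim, rank)
    return lora_A, lora_B
```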
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -56,8 +52,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -102,6 +97,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=False,
+        )
+
         self.experts = get_moe_impl_class()(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -109,7 +110,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             **(
@@ -143,7 +143,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
         )
         self.top_k = config.num_experts_per_tok
-        self.renormalize = config.norm_topk_prob
 
         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,
@@ -180,9 +179,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
@@ -191,17 +189,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
-
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
-
-
-
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
@@ -267,12 +260,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local =
+            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
@@ -745,6 +735,49 @@ class Qwen3MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
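Beyond the shared TopK and forward_split_prefill changes, qwen3_moe.py adjusts its DeepEP path: the gate on the expert computation is now a direct emptiness check (hidden_states.shape[0] > 0) rather than the old forward-mode helper, and self.topk(...) returns a triple whose third element both dispatch call sites discard (topk_weights, topk_idx, _ = self.topk(...)).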
sglang/srt/models/vila.py
CHANGED
@@ -237,7 +237,7 @@ class VILAForConditionalGeneration(nn.Module):
         return cast(LogitsProcessorOutput, output)
 
     def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
-        pixel_values = cast(Tensor, mm_input[0].
+        pixel_values = cast(Tensor, mm_input[0].feature)
 
         ##### BEGIN COPY modeling_vila.py #####
 