sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
CHANGED
@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.hidden_states = torch.zeros(
             (
                 self.max_num_token,
-                self.model_runner.model_config.hidden_size * 3,
+                (
+                    self.model_runner.model_config.hf_config.target_hidden_size
+                    * 3
+                    if hasattr(
+                        self.model_runner.model_config.hf_config,
+                        "target_hidden_size",
+                    )
+                    else self.model_runner.model_config.hidden_size * 3
+                ),
             ),
             dtype=self.model_runner.dtype,
         )
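The hunk above sizes the draft-extend hidden-states buffer from `target_hidden_size` when the HF config defines one, falling back to the model's own `hidden_size` otherwise, both multiplied by 3 as in the hunk. A minimal standalone illustration of that selection rule, using a stub config object rather than sglang's ModelConfig:

import torch

# Stand-in for an HF config; real configs may or may not define target_hidden_size.
class _StubHfConfig:
    hidden_size = 2048
    target_hidden_size = 4096  # present here, so it wins below

cfg = _StubHfConfig()
max_num_token = 8

# Same conditional as in the hunk: prefer target_hidden_size * 3, else hidden_size * 3.
hidden_dim = (
    cfg.target_hidden_size * 3
    if hasattr(cfg, "target_hidden_size")
    else cfg.hidden_size * 3
)
hidden_states = torch.zeros((max_num_token, hidden_dim), dtype=torch.float16)
assert hidden_states.shape == (8, 4096 * 3)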
sglang/srt/two_batch_overlap.py
CHANGED
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
sglang/srt/utils.py
CHANGED
@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found


-def load_audio(
+def load_audio(
+    audio_file: str, sr: Optional[int] = None, mono: bool = True
+) -> np.ndarray:
     # Use soundfile here, since librosa use it under the hood,
     # and librosa will not support audio loading in the future
     import soundfile as sf
     from scipy.signal import resample

+    if sr is None:
+        sr = 16000
+
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
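The new `load_audio` signature makes the target sample rate optional and defaults it to 16 kHz. The rest of the function body is not shown in this hunk, so the following is only a sketch of the behavior the visible lines imply (bytes vs. path input, soundfile loading, scipy resampling); the helper name and the downmix/resample details are assumptions:

from io import BytesIO
from typing import Optional, Union

import numpy as np
import soundfile as sf
from scipy.signal import resample


def load_audio_sketch(
    audio_file: Union[str, bytes], sr: Optional[int] = None, mono: bool = True
) -> np.ndarray:
    # Mirror the new default: an unspecified sample rate falls back to 16 kHz.
    if sr is None:
        sr = 16000
    # Accept raw bytes or a path, as the hunk's isinstance check suggests.
    if isinstance(audio_file, bytes):
        audio, original_sr = sf.read(BytesIO(audio_file))
    else:
        audio, original_sr = sf.read(audio_file)
    # Assumed post-processing: downmix to mono and resample to the target rate.
    if mono and audio.ndim > 1:
        audio = audio.mean(axis=1)
    if original_sr != sr:
        audio = resample(audio, int(len(audio) * sr / original_sr))
    return audio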
@@ -1417,6 +1422,13 @@ def get_nvgpu_memory_capacity():
         ]

         if not memory_values:
+            # Fallback to torch.cuda.mem_get_info() when failed to get memory capacity from nvidia-smi,
+            # typically in NVIDIA MIG mode.
+            if torch.cuda.is_available():
+                logger.warning(
+                    "Failed to get GPU memory capacity from nvidia-smi, falling back to torch.cuda.mem_get_info()."
+                )
+                return torch.cuda.mem_get_info()[1] // 1024 // 1024  # unit: MB
             raise ValueError("No GPU memory values found.")

         # Return the minimum memory value
@@ -2880,3 +2892,17 @@ def parse_module_path(module_path, function_name, create_dummy):
             return final_module, getattr(final_module, function_name)

     return final_module, None
+
+
+# LoRA-related constants and utilities
+SUPPORTED_LORA_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+LORA_TARGET_ALL_MODULES = "all"
sglang/test/runners.py
CHANGED
@@ -134,10 +134,12 @@ class HFRunner:
         model_type: str = "generation",
         output_str_only: bool = False,
         trust_remote_code: bool = False,
+        patch_model_do_sample_false: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
         self.trust_remote_code = trust_remote_code
+        self.patch_model_do_sample_false = patch_model_do_sample_false

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ class HFRunner:
                         torch_dtype=torch_dtype,
                         output_str_only=self.output_str_only,
                         token_ids_logprob=token_ids_logprob,
+                        patch_model_do_sample_false=self.patch_model_do_sample_false,
                     )
                 )
             elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ class HFRunner:
         lora_paths: Optional[List[str]] = None,
         output_str_only: bool = False,
         token_ids_logprob: Optional[int] = None,
+        patch_model_do_sample_false: Optional[bool] = False,
     ) -> ModelOutput:
         output_strs = []
         top_input_logprobs = []
@@ -407,7 +411,8 @@ class HFRunner:
             )
         else:
             model = base_model
-
+        if patch_model_do_sample_false:
+            model.generation_config.do_sample = False
         outputs = model.generate(
             input_ids=input_ids,
             generation_config=GenerationConfig(
@@ -481,7 +486,7 @@ class SRTRunner:
         torch_dtype: torch.dtype,
         model_type: str,
         tp_size: int = 1,
-
+        model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
@@ -505,6 +510,9 @@ class SRTRunner:
         torchao_config: Optional[str] = None,
         cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
+        enable_lora: Optional[bool] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -523,7 +531,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-
+            model_impl=model_impl,
             torchao_config=torchao_config,
             mem_fraction_static=mem_fraction_static,
             trust_remote_code=trust_remote_code,
@@ -543,6 +551,9 @@ class SRTRunner:
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
             sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
+            enable_lora=enable_lora,
             **spec_kwargs,
         )

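`patch_model_do_sample_false` threads through HFRunner so that the reference model's own `generation_config` has sampling disabled before `generate()` runs. A standalone Hugging Face snippet showing the same patch outside sglang; the tiny model name is only a placeholder for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

name = "sshleifer/tiny-gpt2"  # placeholder model, not part of the diff
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# The patch from the hunk: force greedy decoding at the model-config level so
# generate() does not fall back to the checkpoint's sampling defaults.
model.generation_config.do_sample = False

input_ids = tokenizer("hello world", return_tensors="pt").input_ids
output = model.generate(
    input_ids=input_ids,
    generation_config=GenerationConfig(max_new_tokens=8, do_sample=False),
)
print(tokenizer.decode(output[0]))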
sglang/test/test_block_fp8.py
CHANGED
@@ -6,6 +6,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

         with torch.inference_mode():
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
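These test diffs reflect the routing API change in sglang.srt.layers.moe.topk: `select_experts` now returns a single routing result that the Triton `fused_moe` consumes directly, and call sites that still want raw tensors unpack three values instead of two (see the `topk_weights, topk_ids, _ = ...` changes in the EP, w4a8, and fp4 tests below). A rough standalone sketch of that calling convention; the field names and the routing math are illustrative, not sglang's implementation:

from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    # Only the three-way unpacking is visible in the diffs; these names are guesses.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor


def select_experts_sketch(hidden_states, router_logits, top_k, renormalize=False):
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return TopKOutputSketch(topk_weights, topk_ids, router_logits)


logits = torch.randn(4, 8)
routing = select_experts_sketch(None, logits, top_k=2)  # pass the whole object on
topk_weights, topk_ids, _ = routing                     # or unpack three values
assert topk_weights.shape == (4, 2) and topk_ids.shape == (4, 2)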
sglang/test/test_block_fp8_ep.py
CHANGED
@@ -40,7 +40,7 @@ def ep_moe(
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
sglang/test/test_custom_ops.py
CHANGED
@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
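The test now selects its FP8 dtype via `is_fp8_fnuz()`, so ROCm devices that use the `e4m3fnuz` encoding get `torch.float8_e4m3fnuz` while CUDA keeps `torch.float8_e4m3fn`. A small standalone check of that switch (requires a PyTorch build with FP8 dtypes); the boolean flag stands in for the platform probe:

import torch


def pick_fp8_dtype(is_fnuz_platform: bool) -> torch.dtype:
    # Mirrors the module-level `fp8_dtype = ... if _is_fp8_fnuz else ...` line above.
    return torch.float8_e4m3fnuz if is_fnuz_platform else torch.float8_e4m3fn


finfo = torch.finfo(pick_fp8_dtype(False))
print(finfo.min, finfo.max)  # float8_e4m3fn spans roughly [-448, 448]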
sglang/test/test_cutlass_w4a8_moe.py
CHANGED
@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2

     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
sglang/test/test_fp4_moe.py
CHANGED
@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(

     score = torch.randn((m, e), device="cuda", dtype=dtype)

-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )

     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
sglang/test/test_marlin_moe.py
ADDED
@@ -0,0 +1,286 @@
+import types
+from typing import Optional
+
+import pytest
+import torch
+from sgl_kernel import fused_marlin_moe
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+
+def stack_and_dev(tensors: list[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def torch_experts(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
+    apply_router_weights_on_input: bool = False,
+) -> torch.Tensor:
+    assert (
+        global_num_experts == -1
+        or (global_num_experts == w1.shape[0] and expert_map is None)
+        or (expert_map is not None and global_num_experts == expert_map.shape[0])
+    )
+
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    print("quant_dtype", quant_dtype)
+    # exit(0)
+    if apply_router_weights_on_input:
+        assert topk == 1
+        a = a * topk_weight.to(a.dtype)
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    num_experts = w1.shape[0]
+
+    topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
+
+    f32 = torch.float32
+
+    for i in range(num_experts):
+        mask = topk_ids == i
+        if mask.sum():
+            if quant_dtype is None:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+
+    if apply_router_weights_on_input:
+        return out
+    else:
+        return (
+            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
+            .sum(dim=1)
+            .to(out.dtype)
+        )
+
+
+def torch_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    return torch_experts(
+        a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
+    )
+
+
+def marlin_moe_generate_valid_test_cases():
+    import itertools
+
+    m_list = [1, 123, 666]
+    n_list = [128, 1024]
+    k_list = [256, 2048]
+    e_list = [4, 12]
+    topk_list = [2, 3]
+    dtype_list = [torch.half, torch.bfloat16]
+    group_size_list = [128]
+    act_order_list = [True, False]
+    quant_type_list = [
+        scalar_types.uint4,
+        scalar_types.uint4b8,
+    ]
+    is_k_full_list = [True, False]
+
+    all_combinations = itertools.product(
+        m_list,
+        n_list,
+        k_list,
+        e_list,
+        topk_list,
+        dtype_list,
+        group_size_list,
+        act_order_list,
+        quant_type_list,
+        is_k_full_list,
+    )
+
+    def is_invalid(
+        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
+    ):
+
+        # Filter act_order
+        if act_order:
+            if group_size in (-1, k, n):
+                return False
+            if quant_type not in [scalar_types.uint4b8]:
+                return False
+        elif not is_k_full:
+            return False
+
+        return True
+
+    cases = []
+    for case in all_combinations:
+        if is_invalid(*case):
+            cases.append(case)
+    return cases
+
+
+@pytest.mark.flaky(reruns=2)
+@pytest.mark.parametrize(
+    ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
+    marlin_moe_generate_valid_test_cases(),
+)
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    group_size: int,
+    act_order: bool,
+    quant_type: ScalarType,
+    is_k_full: bool,
+):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device not available")
+
+    torch.manual_seed(0)
+
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+        if has_zp:
+            return
+    else:
+        if not is_k_full:
+            return
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
+
+    e_map = None
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    zeros1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        if has_zp:
+            w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            zeros1_l.append(zeros1)
+        else:
+            test_perm = torch.randperm(k)
+            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            g_idx1_l.append(g_idx1)
+            sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
+    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
+    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    zeros2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        if has_zp:
+            w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            zeros2_l.append(zeros2)
+        else:
+            test_perm = torch.randperm(n)
+            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            g_idx2_l.append(g_idx2)
+            sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
+    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
+    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    from sglang.srt.layers.moe.topk import fused_topk_torch_native
+
+    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
+
+    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        scales1,
+        scales2,
+        score,
+        topk_weights,
+        topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=zeros1,
+        w2_zeros=zeros2,
+        num_bits=4,
+        is_k_full=is_k_full,
+    )
+
+    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
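The new test's `torch_moe` reference routes tokens with a plain softmax-plus-top-k before comparing against `fused_marlin_moe`. A tiny CPU-only check of that routing step, independent of the Marlin kernels:

import torch

score = torch.randn(4, 8)                     # (tokens, experts) router logits
probs = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(probs, 2)  # per-token expert weights and ids
assert topk_weight.shape == (4, 2) and topk_ids.shape == (4, 2)
assert torch.all(topk_ids < 8)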