sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }
 
+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
         },
         {
-            "name": "
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
+        },
+        {
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
+        },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
         },
     ],
     "page_size_consistency": [
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )
 
-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for
-
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create Q, K, V random tensors for given batch size with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
+        """
         device = config["device"]
-
+        target_dtype = dtype_override or config["dtype"]
 
-
-
-
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
             device=device,
         )
-
-            (batch_size, config["
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
         )
+
+        # V tensor (unchanged)
         v = torch.randn(
             (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
+            dtype=config["dtype"],
             device=device,
         )
-
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
 
         # Create components
         (
|
|
487
571
|
|
488
572
|
# Create Q, K, V tensors for current decode step
|
489
573
|
torch.manual_seed(config["seed_qkv"])
|
490
|
-
|
574
|
+
|
575
|
+
q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
|
576
|
+
self._create_qkv_tensors(batch_size, config)
|
577
|
+
)
|
578
|
+
q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
|
579
|
+
q_nope_ref.clone(),
|
580
|
+
q_rope_ref.clone(),
|
581
|
+
k_nope_ref.clone(),
|
582
|
+
k_rope_ref.clone(),
|
583
|
+
v_ref.clone(),
|
584
|
+
)
|
585
|
+
tolerance = config["tolerance"]
|
586
|
+
|
587
|
+
extra_args = {}
|
588
|
+
if use_fp8:
|
589
|
+
# TRT kernel applies RoPE + FP8 quantization internally
|
590
|
+
# pre-apply RoPE on the reference (FlashInfer) path here so
|
591
|
+
# both paths share the same rope params/cache while keeping
|
592
|
+
# the TRT path unrotated.
|
593
|
+
rotary_emb = build_rotary_emb(config)
|
594
|
+
q_rope_ref, k_rope_ref = rotary_emb(
|
595
|
+
fb_reference.positions, q_rope_ref, k_rope_ref
|
596
|
+
)
|
597
|
+
extra_args = {
|
598
|
+
"cos_sin_cache": rotary_emb.cos_sin_cache,
|
599
|
+
"is_neox": rotary_emb.is_neox_style,
|
600
|
+
}
|
601
|
+
|
602
|
+
dtype = q_rope_ref.dtype
|
603
|
+
q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
|
604
|
+
q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
|
605
|
+
k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
|
606
|
+
k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
|
491
607
|
|
492
608
|
# Run forward decode on both backends
|
493
609
|
out_trtllm = trtllm_backend.forward_decode(
|
494
|
-
|
610
|
+
q_nope_trt,
|
611
|
+
k_nope_trt,
|
612
|
+
None,
|
613
|
+
layer,
|
614
|
+
fb_trtllm,
|
615
|
+
q_rope=q_rope_trt,
|
616
|
+
k_rope=k_rope_trt,
|
617
|
+
**extra_args,
|
495
618
|
)
|
619
|
+
|
620
|
+
# Reference backend should also take separate components, not concatenated
|
496
621
|
out_reference = reference_backend.forward_decode(
|
497
|
-
|
622
|
+
q_nope_ref,
|
623
|
+
k_nope_ref,
|
624
|
+
v_ref,
|
625
|
+
layer,
|
626
|
+
fb_reference,
|
627
|
+
q_rope=q_rope_ref,
|
628
|
+
k_rope=k_rope_ref,
|
498
629
|
)
|
499
630
|
|
500
631
|
# Compare outputs
|
501
632
|
comparison_passed = compare_outputs(
|
502
|
-
out_trtllm, out_reference, tolerance=
|
633
|
+
out_trtllm, out_reference, tolerance=tolerance
|
503
634
|
)
|
504
635
|
|
505
636
|
self.assertTrue(
|
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = backend.forward_decode(
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         backend.init_forward_metadata(fb)
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
-
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-
-            (batch_size, config["num_kv_heads"],
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v
 
         # Run forward decode
-        output = backend.forward_decode(
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Shape and sanity checks
         expected_shape = (
sglang/test/doc_patch.py
ADDED
@@ -0,0 +1,59 @@
+"""
+Do some monkey patch to make the documentation compilation faster and more reliable.
+
+- Avoid port conflicts
+- Reduce the server launch time
+"""
+
+import weakref
+
+import nest_asyncio
+
+nest_asyncio.apply()
+
+import sglang.srt.server_args as server_args_mod
+from sglang.utils import execute_shell_command, reserve_port
+
+DEFAULT_MAX_RUNNING_REQUESTS = 128
+DEFAULT_MAX_TOTAL_TOKENS = 20480  # To allow multiple servers on the same machine
+
+_original_post_init = server_args_mod.ServerArgs.__post_init__
+
+
+def patched_post_init(self):
+    _original_post_init(self)
+    if self.max_running_requests is None:
+        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
+    if self.max_total_tokens is None:
+        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
+    self.cuda_graph_max_bs = 4
+
+
+server_args_mod.ServerArgs.__post_init__ = patched_post_init
+
+process_socket_map = weakref.WeakKeyDictionary()
+
+
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if port is None:
+        port, lock_socket = reserve_port(host)
+    else:
+        lock_socket = None
+
+    extra_flags = (
+        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
+        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
+        f"--cuda-graph-max-bs 4"
+    )
+
+    full_command = f"{command} --port {port} {extra_flags}"
+    process = execute_shell_command(full_command)
+
+    if lock_socket is not None:
+        process_socket_map[process] = lock_socket
+
+    return process, port
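A minimal usage sketch of the helper added above (illustrative only: the launch command, model path, and the wait_for_server/terminate_process helpers are assumptions outside this diff; only launch_server_cmd's signature and its (process, port) return value come from the file itself):

# Hypothetical usage of sglang.test.doc_patch.launch_server_cmd.
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import terminate_process, wait_for_server  # assumed helpers

# The command and model path below are placeholders.
server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
# ... run documentation examples against the server ...
terminate_process(server_process)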
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -12,7 +12,7 @@ import time
 
 import numpy as np
 
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/few_shot_gsm8k_engine.py
CHANGED
@@ -8,7 +8,7 @@ import time
 import numpy as np
 
 import sglang as sgl
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/run_eval.py
CHANGED
@@ -65,9 +65,10 @@ def run_eval(args):
 
     sampler = ChatCompletionSampler(
         model=args.model,
-        max_tokens=2048,
+        max_tokens=getattr(args, "max_tokens", 2048),
         base_url=base_url,
         temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
     )
 
     # Run eval
@@ -120,7 +121,9 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--reasoning-effort", type=str)
     args = parser.parse_args()
 
     run_eval(args)
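With the added arguments, a run can cap completion length and set a reasoning-effort level directly from the command line, e.g. (illustrative invocation; the values and any server/base-url flags are placeholders, only the flag names come from this diff):

python3 -m sglang.test.run_eval --eval-name mmlu --max-tokens 4096 --reasoning-effort high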
sglang/test/runners.py
CHANGED
@@ -568,8 +568,8 @@ class SRTRunner:
         else:
             self.tokenizer = None
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
-        return self.engine.load_lora_adapter(lora_name, lora_path)
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
+        return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
 
     def unload_lora_adapter(self, lora_name: str):
         return self.engine.unload_lora_adapter(lora_name)
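A minimal sketch of the extended call (the runner construction is elided and the adapter name and path are placeholders; only the new pinned keyword comes from this diff):

# Hypothetical: 'runner' is an SRTRunner created elsewhere with LoRA support enabled.
runner.load_lora_adapter("demo_adapter", "/path/to/lora_adapter", pinned=True)
# ... run inference with the adapter ...
runner.unload_lora_adapter("demo_adapter")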
sglang/test/simple_eval_common.py
CHANGED
@@ -91,6 +91,7 @@ class ChatCompletionSampler(SamplerBase):
         model: Optional[str] = None,
         system_message: Optional[str] = None,
         temperature: float = 0.0,
+        reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
@@ -102,7 +103,11 @@ class ChatCompletionSampler(SamplerBase):
         self.system_message = system_message
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.reasoning_effort = reasoning_effort
         self.image_format = "url"
+        print(
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+        )
 
     def _handle_image(
         self,
@@ -138,6 +143,7 @@ class ChatCompletionSampler(SamplerBase):
                 messages=message_list,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
+                reasoning_effort=self.reasoning_effort,
             )
             return response.choices[0].message.content
         # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
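For reference, a sketch of constructing the sampler with the new field (the base URL and settings are placeholders; only the reasoning_effort parameter is introduced by this diff):

from sglang.test.simple_eval_common import ChatCompletionSampler

sampler = ChatCompletionSampler(
    base_url="http://127.0.0.1:30000/v1",  # placeholder endpoint
    model=None,  # let the server pick its default model
    temperature=0.0,
    max_tokens=2048,
    reasoning_effort="low",  # forwarded to chat.completions.create
)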
sglang/test/simple_eval_gpqa.py
CHANGED
@@ -71,6 +71,8 @@ class GPQAEval(Eval):
                 )
             ]
             response_text = sampler(prompt_messages)
+            if response_text is None:
+                response_text = ""
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
             extracted_answer = match.group(1) if match else None
             score = 1.0 if extracted_answer == correct_answer else 0.0
sglang/test/test_fp4_moe.py
CHANGED
@@ -1,6 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import Callable
+
 import pytest
 import torch
+from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 from sgl_kernel import scaled_fp4_quant
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -111,15 +114,16 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
-
-
-
-
-
-
-
+def check_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    moe_impl: Callable,
+    flip_w13: bool,
 ):
-
     torch.manual_seed(7)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -167,38 +171,18 @@ def test_cutlass_fp4_moe_no_graph(
 
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
-
-    ab_strides_13 = torch.full(
-        (e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
-    )
-    c_strides_13 = torch.full(
-        (e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
-    )
-    ab_strides_2 = torch.full(
-        (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
-    )
-    c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
-    params = CutlassMoEParams(
-        CutlassMoEType.BlockscaledFP4,
-        device=a.device,
-        num_experts=e,
-        intermediate_size_per_partition=n,  # n
-        hidden_size=k,
-    )  # k
-    cutlass_output = cutlass_moe_fp4(
+    test_output = moe_impl(
         a=a,
-
-
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        w1_q=w1_q,
+        w2_q=w2_q,
+        a1_gs=a1_gs,
         w1_blockscale=w1_blockscale,
         w1_alphas=(1 / w1_gs),
-
-        w2_fp4=w2_q,
+        a2_gs=a2_gs,
         w2_blockscale=w2_blockscale,
         w2_alphas=(1 / w2_gs),
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        params=params,
-        apply_router_weight_on_input=False,
     )
 
     # Reference check:
@@ -237,10 +221,108 @@ def test_cutlass_fp4_moe_no_graph(
         block_size=quant_blocksize,
     )
 
+    if flip_w13:
+        dim = -2
+        size = w1_d.size(dim)
+        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+        half = size // 2
+        # Reorder weight
+        w1, w3 = w1_d.split(half, dim=dim)
+        w1_d = torch.cat([w3, w1], dim=dim).contiguous()
+
     torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
 
-    torch.testing.assert_close(torch_output,
+    torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_cutlass_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def cutlass_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device=a.device,
+            num_experts=e,
+            intermediate_size_per_partition=n,  # n
+            hidden_size=k,
+        )  # k
+        return cutlass_moe_fp4(
+            a=a,
+            a1_gscale=a1_gs,
+            w1_fp4=w1_q,
+            w1_blockscale=w1_blockscale,
+            w1_alphas=w1_alphas,
+            a2_gscale=a2_gs,
+            w2_fp4=w2_q,
+            w2_blockscale=w2_blockscale,
+            w2_alphas=w2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=params,
+            apply_router_weight_on_input=False,
+        )
+
+    check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_flashinfer_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def flashinfer_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        return flashinfer_cutlass_fused_moe(
+            a,
+            topk_ids.to(torch.int),
+            topk_weights,
+            w1_q.view(torch.long),
+            w2_q.view(torch.long),
+            a.dtype,
+            quant_scales=[
+                a1_gs,
+                w1_blockscale.view(torch.int32),
+                w1_alphas,
+                a2_gs,
+                w2_blockscale.view(torch.int32),
+                w2_alphas,
+            ],
+        )[0]
+
+    check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
sglang/test/test_utils.py
CHANGED
@@ -83,7 +83,7 @@ DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-30B-A3B"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
-DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
+DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,zai-org/GLM-4.5-Air-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-VL-3B-Instruct"