sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
|
|
41
41
|
"v_head_dim": 512,
|
42
42
|
"num_kv_heads": 1,
|
43
43
|
"layer_id": 0,
|
44
|
+
"tp_q_head_num": 128,
|
45
|
+
"tp_k_head_num": 128,
|
46
|
+
"prefill_head_dim": 192,
|
47
|
+
"prefill_v_head_dim": 128,
|
44
48
|
}
|
45
49
|
|
46
50
|
ROPE_BASE = 10000
|
@@ -92,7 +96,7 @@ TEST_CASES = {
|
|
92
96
|
"description": "Medium-scale batch",
|
93
97
|
},
|
94
98
|
],
|
95
|
-
"
|
99
|
+
"output_match": [
|
96
100
|
{
|
97
101
|
"name": "single_fp16",
|
98
102
|
"batch_size": 1,
|
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
322
326
|
config.update(test_case)
|
323
327
|
return config
|
324
328
|
|
325
|
-
def _create_model_components(self, config):
|
329
|
+
def _create_model_components(self, config, is_prefill=False):
|
326
330
|
"""Create model runners, backends, and layer for testing."""
|
327
331
|
# Create model runners
|
328
332
|
model_runner_trtllm = MockModelRunner(config)
|
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
332
336
|
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
333
337
|
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
334
338
|
|
339
|
+
head_dim = (
|
340
|
+
config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
341
|
+
if not is_prefill
|
342
|
+
else config["prefill_head_dim"]
|
343
|
+
)
|
344
|
+
v_head_dim = (
|
345
|
+
config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
|
346
|
+
)
|
347
|
+
|
335
348
|
# Create RadixAttention layer
|
336
349
|
layer = RadixAttention(
|
337
350
|
num_heads=config["num_attention_heads"],
|
338
|
-
head_dim=
|
351
|
+
head_dim=head_dim,
|
339
352
|
scaling=model_runner_trtllm.model_config.scaling,
|
340
353
|
num_kv_heads=config["num_kv_heads"],
|
341
354
|
layer_id=config["layer_id"],
|
342
|
-
v_head_dim=
|
355
|
+
v_head_dim=v_head_dim,
|
343
356
|
prefix="attn_mqa",
|
344
357
|
)
|
345
358
|
|
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
524
537
|
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
525
538
|
print(f"\nRunning decode output matching tests...")
|
526
539
|
|
527
|
-
for test_case in TEST_CASES["
|
540
|
+
for test_case in TEST_CASES["output_match"]:
|
528
541
|
with self.subTest(test_case=test_case["name"]):
|
529
542
|
print(f" Testing {test_case['name']}: {test_case['description']}")
|
530
543
|
|
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
1099
1112
|
self.assertIsNotNone(metadata_3.block_kv_indices)
|
1100
1113
|
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
1101
1114
|
|
1115
|
+
def test_prefill_output_match_self_attention(self):
|
1116
|
+
"""Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
|
1117
|
+
print(f"\nRunning prefill output tests...")
|
1118
|
+
|
1119
|
+
for test_case in TEST_CASES["output_match"][:2]: # Just a subset for speed
|
1120
|
+
with self.subTest(test_case=test_case["name"]):
|
1121
|
+
print(
|
1122
|
+
f"Prefill Testing {test_case['name']}: {test_case['description']}"
|
1123
|
+
)
|
1124
|
+
|
1125
|
+
config = self._merge_config(test_case)
|
1126
|
+
batch_size = config["batch_size"]
|
1127
|
+
max_seq_len = config["max_seq_len"]
|
1128
|
+
|
1129
|
+
# Create components
|
1130
|
+
(
|
1131
|
+
model_runner_trtllm,
|
1132
|
+
model_runner_reference,
|
1133
|
+
trtllm_backend,
|
1134
|
+
reference_backend,
|
1135
|
+
layer,
|
1136
|
+
) = self._create_model_components(config, is_prefill=True)
|
1137
|
+
|
1138
|
+
# Prefill uses full sequences
|
1139
|
+
seq_lens = torch.full(
|
1140
|
+
(batch_size,), max_seq_len, device=config["device"]
|
1141
|
+
)
|
1142
|
+
|
1143
|
+
def _create_forward_batch_prefill(
|
1144
|
+
batch_size,
|
1145
|
+
seq_lens,
|
1146
|
+
extend_prefix_lens,
|
1147
|
+
backend,
|
1148
|
+
model_runner,
|
1149
|
+
config,
|
1150
|
+
):
|
1151
|
+
"""Create a forward batch for the given backend."""
|
1152
|
+
|
1153
|
+
fb = ForwardBatch(
|
1154
|
+
batch_size=batch_size,
|
1155
|
+
input_ids=torch.randint(
|
1156
|
+
0, 100, (batch_size, 1), device=config["device"]
|
1157
|
+
),
|
1158
|
+
out_cache_loc=torch.arange(batch_size, device=config["device"]),
|
1159
|
+
seq_lens_sum=int(seq_lens.sum().item()),
|
1160
|
+
extend_prefix_lens=extend_prefix_lens,
|
1161
|
+
extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
|
1162
|
+
extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
|
1163
|
+
.cpu()
|
1164
|
+
.int()
|
1165
|
+
.tolist(),
|
1166
|
+
forward_mode=ForwardMode.EXTEND,
|
1167
|
+
req_pool_indices=torch.arange(
|
1168
|
+
batch_size, device=config["device"]
|
1169
|
+
),
|
1170
|
+
seq_lens=seq_lens,
|
1171
|
+
seq_lens_cpu=seq_lens.cpu(),
|
1172
|
+
attn_attend_prefix_cache=False,
|
1173
|
+
mha_return_lse=False,
|
1174
|
+
attn_backend=backend,
|
1175
|
+
)
|
1176
|
+
fb.req_to_token_pool = model_runner.req_to_token_pool
|
1177
|
+
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
1178
|
+
|
1179
|
+
# Add position information for RoPE
|
1180
|
+
fb.positions = torch.arange(batch_size, device=config["device"])
|
1181
|
+
|
1182
|
+
return fb
|
1183
|
+
|
1184
|
+
# Create forward batches
|
1185
|
+
fb_trtllm = _create_forward_batch_prefill(
|
1186
|
+
batch_size,
|
1187
|
+
seq_lens.clone(),
|
1188
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1189
|
+
trtllm_backend,
|
1190
|
+
model_runner_trtllm,
|
1191
|
+
config,
|
1192
|
+
)
|
1193
|
+
fb_reference = _create_forward_batch_prefill(
|
1194
|
+
batch_size,
|
1195
|
+
seq_lens.clone(),
|
1196
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1197
|
+
reference_backend,
|
1198
|
+
model_runner_reference,
|
1199
|
+
config,
|
1200
|
+
)
|
1201
|
+
|
1202
|
+
# Initialize metadata for both backends
|
1203
|
+
trtllm_backend.init_forward_metadata(fb_trtllm)
|
1204
|
+
reference_backend.init_forward_metadata(fb_reference)
|
1205
|
+
|
1206
|
+
# Create Q, K, V tensors for prefill
|
1207
|
+
torch.manual_seed(config["seed_qkv"])
|
1208
|
+
|
1209
|
+
def _create_qkv_tensors_prefill(
|
1210
|
+
batch_size, seq_len, config, dtype_override=None
|
1211
|
+
):
|
1212
|
+
"""Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
|
1213
|
+
device = config["device"]
|
1214
|
+
dtype = dtype_override or config["dtype"]
|
1215
|
+
|
1216
|
+
total_tokens = batch_size * seq_len
|
1217
|
+
|
1218
|
+
tp_q_head_num = config["tp_q_head_num"]
|
1219
|
+
tp_k_head_num = config["tp_k_head_num"]
|
1220
|
+
head_dim = config["prefill_head_dim"]
|
1221
|
+
v_head_dim = config["prefill_v_head_dim"]
|
1222
|
+
|
1223
|
+
q = torch.randn(
|
1224
|
+
(total_tokens, tp_q_head_num * head_dim),
|
1225
|
+
dtype=dtype,
|
1226
|
+
device=device,
|
1227
|
+
)
|
1228
|
+
k = torch.randn(
|
1229
|
+
(total_tokens, tp_k_head_num * head_dim),
|
1230
|
+
dtype=dtype,
|
1231
|
+
device=device,
|
1232
|
+
)
|
1233
|
+
v = torch.randn(
|
1234
|
+
(total_tokens, tp_k_head_num * v_head_dim),
|
1235
|
+
dtype=dtype,
|
1236
|
+
device=device,
|
1237
|
+
)
|
1238
|
+
|
1239
|
+
# Reshape as requested
|
1240
|
+
q = q.view(-1, tp_q_head_num, head_dim)
|
1241
|
+
k = k.view(-1, tp_k_head_num, head_dim)
|
1242
|
+
v = v.view(-1, tp_k_head_num, v_head_dim)
|
1243
|
+
|
1244
|
+
return q, k, v
|
1245
|
+
|
1246
|
+
q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
|
1247
|
+
# Run prefill on both backends
|
1248
|
+
out_trtllm = trtllm_backend.forward_extend(
|
1249
|
+
q, k, v, layer, fb_trtllm, False
|
1250
|
+
).view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
1251
|
+
out_reference = reference_backend.forward_extend(
|
1252
|
+
q, k, v, layer, fb_reference, False
|
1253
|
+
)
|
1254
|
+
|
1255
|
+
tolerance = config.get("tolerance", 1e-2)
|
1256
|
+
comparison_passed = compare_outputs(
|
1257
|
+
out_trtllm, out_reference, tolerance=tolerance
|
1258
|
+
)
|
1259
|
+
self.assertTrue(
|
1260
|
+
comparison_passed,
|
1261
|
+
f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
|
1262
|
+
f"Config: {test_case['name']}, "
|
1263
|
+
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
|
1264
|
+
)
|
1265
|
+
|
1102
1266
|
|
1103
1267
|
if __name__ == "__main__":
|
1104
1268
|
unittest.main()
|
sglang/test/runners.py
CHANGED
@@ -505,6 +505,7 @@ class SRTRunner:
|
|
505
505
|
mem_fraction_static: float = 0.65,
|
506
506
|
trust_remote_code: bool = False,
|
507
507
|
speculative_draft_model_path: Optional[str] = None,
|
508
|
+
speculative_draft_model_revision: Optional[str] = None,
|
508
509
|
speculative_algorithm: Optional[str] = None,
|
509
510
|
speculative_num_steps: Optional[int] = None,
|
510
511
|
speculative_eagle_topk: Optional[int] = None,
|
@@ -526,6 +527,9 @@ class SRTRunner:
|
|
526
527
|
spec_kwargs = {}
|
527
528
|
if speculative_draft_model_path:
|
528
529
|
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
530
|
+
spec_kwargs["speculative_draft_model_revision"] = (
|
531
|
+
speculative_draft_model_revision
|
532
|
+
)
|
529
533
|
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
530
534
|
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
531
535
|
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
sglang/test/test_cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
|
|
9
9
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
10
10
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
11
11
|
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
12
|
+
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
12
13
|
|
13
14
|
|
14
15
|
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
@@ -21,7 +22,7 @@ def calc_diff(x, y):
|
|
21
22
|
|
22
23
|
def get_model_config(tp_size: int):
|
23
24
|
config = AutoConfig.from_pretrained(
|
24
|
-
"deepseek-ai/
|
25
|
+
"deepseek-ai/Deepseek-R1", trust_remote_code=True
|
25
26
|
)
|
26
27
|
E = config.n_routed_experts
|
27
28
|
topk = config.num_experts_per_tok
|
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|
152
153
|
problem_sizes2,
|
153
154
|
)
|
154
155
|
|
156
|
+
topk_output = StandardTopKOutput(
|
157
|
+
topk_weights=topk_weights,
|
158
|
+
topk_ids=topk_ids,
|
159
|
+
router_logits=torch.randn(
|
160
|
+
(batch_size, topk), device=topk_weights.device, dtype=dtype
|
161
|
+
),
|
162
|
+
)
|
163
|
+
|
164
|
+
moe_runner_config = MoeRunnerConfig(
|
165
|
+
num_experts=E,
|
166
|
+
top_k=topk,
|
167
|
+
hidden_size=H,
|
168
|
+
intermediate_size_per_partition=I,
|
169
|
+
params_dtype=dtype,
|
170
|
+
activation="silu",
|
171
|
+
inplace=False,
|
172
|
+
)
|
173
|
+
|
155
174
|
# Note: Triton expects non-transposed weights
|
156
|
-
moe_config = MoeRunnerConfig(inplace=False)
|
157
175
|
triton_lambda = lambda: fused_experts(
|
158
176
|
x,
|
159
177
|
w1,
|
160
178
|
w2,
|
161
|
-
|
162
|
-
|
179
|
+
topk_output,
|
180
|
+
moe_runner_config,
|
163
181
|
use_fp8_w8a8=True,
|
164
182
|
w1_scale=w1_scale,
|
165
183
|
w2_scale=w2_scale,
|
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|
224
242
|
x,
|
225
243
|
w1, # Original shape
|
226
244
|
w2, # Original shape
|
227
|
-
|
228
|
-
|
245
|
+
topk_output,
|
246
|
+
moe_runner_config,
|
229
247
|
use_fp8_w8a8=True,
|
230
248
|
w1_scale=w1_scale,
|
231
249
|
w2_scale=w2_scale,
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import requests
|
4
|
+
|
5
|
+
from sglang.srt.utils import kill_process_tree
|
6
|
+
from sglang.test.test_utils import (
|
7
|
+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
8
|
+
CustomTestCase,
|
9
|
+
popen_with_error_check,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
class TestDisaggregationBase(CustomTestCase):
|
14
|
+
@classmethod
|
15
|
+
def setUpClass(cls):
|
16
|
+
cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
|
17
|
+
pass
|
18
|
+
|
19
|
+
@classmethod
|
20
|
+
def launch_lb(cls):
|
21
|
+
lb_command = [
|
22
|
+
"python3",
|
23
|
+
"-m",
|
24
|
+
"sglang_router.launch_router",
|
25
|
+
"--pd-disaggregation",
|
26
|
+
"--mini-lb", # FIXME: remove this
|
27
|
+
"--prefill",
|
28
|
+
cls.prefill_url,
|
29
|
+
"--decode",
|
30
|
+
cls.decode_url,
|
31
|
+
"--host",
|
32
|
+
cls.base_host,
|
33
|
+
"--port",
|
34
|
+
cls.lb_port,
|
35
|
+
]
|
36
|
+
print("Starting load balancer:", " ".join(lb_command))
|
37
|
+
cls.process_lb = popen_with_error_check(lb_command)
|
38
|
+
cls.wait_server_ready(cls.lb_url + "/health")
|
39
|
+
|
40
|
+
@classmethod
|
41
|
+
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
|
42
|
+
start_time = time.perf_counter()
|
43
|
+
while True:
|
44
|
+
try:
|
45
|
+
response = requests.get(url)
|
46
|
+
if response.status_code == 200:
|
47
|
+
print(f"Server {url} is ready")
|
48
|
+
return
|
49
|
+
except Exception:
|
50
|
+
pass
|
51
|
+
|
52
|
+
if time.perf_counter() - start_time > timeout:
|
53
|
+
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
|
54
|
+
time.sleep(1)
|
55
|
+
|
56
|
+
@classmethod
|
57
|
+
def tearDownClass(cls):
|
58
|
+
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
|
59
|
+
if process:
|
60
|
+
try:
|
61
|
+
kill_process_tree(process.pid)
|
62
|
+
except Exception as e:
|
63
|
+
print(f"Error killing process {process.pid}: {e}")
|
64
|
+
|
65
|
+
# wait for 5 seconds
|
66
|
+
time.sleep(5)
|