sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/utils/profile_merger.py
@@ -0,0 +1,199 @@
+"""Merge Chrome trace files from multiple ranks (TP, DP, PP, EP) into a single trace."""
+
+import glob
+import gzip
+import json
+import logging
+import os
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+class ProfileMerger:
+    """Merge profile traces from all parallelism types: TP, DP, PP, EP."""
+
+    def __init__(self, output_dir: str, profile_id: str):
+        self.output_dir = output_dir
+        self.profile_id = profile_id
+        self.merged_trace_path = os.path.join(
+            output_dir, f"merged-{profile_id}.trace.json.gz"
+        )
+
+        # Rank types in priority order (used for sorting and labeling)
+        self.rank_types = ["tp", "dp", "pp", "ep"]
+
+        # Sort index multipliers: DP (highest) > EP > PP > TP (lowest)
+        # These ensure proper visual ordering in trace viewer
+        self.sort_index_multipliers = {
+            "dp_rank": 100_000_000,
+            "ep_rank": 1_000_000,
+            "pp_rank": 10_000,
+            "tp_rank": 100,
+        }
+
+        # PID threshold for sort_index updates (only update for system PIDs < 1000)
+        self.pid_sort_index_threshold = 1000
+
+    def merge_chrome_traces(self) -> str:
+        """Merge Chrome traces from all ranks into a single trace.
+
+        Returns:
+            Path to merged trace file.
+
+        Raises:
+            ValueError: If no trace files found.
+        """
+        trace_files = self._discover_trace_files()
+        if not trace_files:
+            raise ValueError(f"No trace files found for profile_id: {self.profile_id}")
+
+        logger.info(f"Found {len(trace_files)} trace files to merge")
+
+        merged_trace = {"traceEvents": []}
+        all_device_properties = []
+
+        for trace_file in sorted(trace_files, key=self._get_rank_sort_key):
+            rank_info = self._extract_rank_info(trace_file)
+            logger.info(f"Processing {trace_file} with rank info: {rank_info}")
+
+            output = self._handle_file(trace_file, rank_info)
+
+            merged_trace["traceEvents"].extend(output["traceEvents"])
+
+            if "deviceProperties" in output:
+                all_device_properties.extend(output["deviceProperties"])
+                del output["deviceProperties"]
+
+            for key, value in output.items():
+                if key != "traceEvents" and key not in merged_trace:
+                    merged_trace[key] = value
+
+        if all_device_properties:
+            merged_trace["deviceProperties"] = all_device_properties
+
+        with gzip.open(self.merged_trace_path, "wb") as f:
+            f.write(json.dumps(merged_trace).encode("utf-8"))
+
+        logger.info(f"Merged profile saved to: {self.merged_trace_path}")
+        logger.info(f"Total events merged: {len(merged_trace['traceEvents'])}")
+
+        return self.merged_trace_path
+
+    def _discover_trace_files(self) -> List[str]:
+        """Discover trace files matching profile_id (supports TP/DP/PP/EP formats)."""
+        patterns = [f"{self.profile_id}*.trace.json.gz"]
+
+        trace_files = []
+        for pattern in patterns:
+            search_pattern = os.path.join(self.output_dir, pattern)
+            trace_files.extend(glob.glob(search_pattern))
+
+        trace_files = [
+            f
+            for f in trace_files
+            if not f.endswith(f"merged-{self.profile_id}.trace.json.gz")
+            and not f.endswith("-memory.pickle")
+            and "TP-" in f
+        ]
+        trace_files = list(set(trace_files))
+        return trace_files
+
+    def _extract_rank_info(self, filename: str) -> Dict[str, int]:
+        """Extract rank info (TP/DP/PP/EP) from filename."""
+        basename = os.path.basename(filename)
+        rank_info = {}
+
+        for rank_type in self.rank_types:
+            match = re.search(rf"{rank_type.upper()}-(\d+)", basename)
+            if match:
+                rank_info[f"{rank_type}_rank"] = int(match.group(1))
+
+        return rank_info
+
+    def _create_rank_label(self, rank_info: Dict[str, int]) -> str:
+        parts = []
+        for rank_type in self.rank_types:
+            rank_key = f"{rank_type}_rank"
+            if rank_key in rank_info:
+                parts.append(f"{rank_type.upper()}{rank_info[rank_key]:02d}")
+
+        return f"[{'-'.join(parts)}]" if parts else "[Unknown]"
+
+    def _handle_file(self, path: str, rank_info: Dict[str, int]) -> Dict[str, Any]:
+        logger.info(f"Processing file: {path}")
+
+        try:
+            with gzip.open(path, "rt", encoding="utf-8") as f:
+                trace = json.load(f)
+
+            output = {
+                key: value for key, value in trace.items() if key != "traceEvents"
+            }
+            output["traceEvents"] = self._process_events(
+                trace.get("traceEvents", []), rank_info
+            )
+            return output
+
+        except Exception as e:
+            logger.error(f"Failed to process trace file {path}: {e}")
+            return {"traceEvents": []}
+
+    def _process_events(
+        self, events: List[Dict], rank_info: Dict[str, int]
+    ) -> List[Dict]:
+        """Process events: update sort_index and add rank labels to PIDs."""
+        rank_label = self._create_rank_label(rank_info)
+
+        for event in events:
+            if event.get("name") == "process_sort_index":
+                pid = self._maybe_cast_int(event.get("pid"))
+                if pid is not None and pid < self.pid_sort_index_threshold:
+                    event["args"]["sort_index"] = self._calculate_sort_index(
+                        rank_info, pid
+                    )
+
+            event["pid"] = f"{rank_label} {event['pid']}"
+
+        return events
+
+    def _calculate_sort_index(self, rank_info: Dict[str, int], pid: int) -> int:
+        sort_index = pid
+        for rank_type, multiplier in self.sort_index_multipliers.items():
+            sort_index += rank_info.get(rank_type, 0) * multiplier
+        return sort_index
+
+    def _get_rank_sort_key(self, path: str) -> Tuple[int, int, int, int]:
+        rank_info = self._extract_rank_info(path)
+        return tuple(
+            rank_info.get(f"{rank_type}_rank", 0)
+            for rank_type in ["dp", "ep", "pp", "tp"]
+        )
+
+    def _maybe_cast_int(self, x) -> Optional[int]:
+        try:
+            return int(x)
+        except (ValueError, TypeError):
+            return None
+
+    def get_merge_summary(self) -> Dict[str, Any]:
+        if not os.path.exists(self.merged_trace_path):
+            return {"error": "Merged trace file not found"}
+
+        try:
+            with gzip.open(self.merged_trace_path, "rt") as f:
+                merged_data = json.load(f)
+
+            trace_files = self._discover_trace_files()
+
+            return {
+                "merged_file": self.merged_trace_path,
+                "total_events": len(merged_data.get("traceEvents", [])),
+                "total_files": len(trace_files),
+                "source_files": [os.path.basename(f) for f in trace_files],
+                "profile_id": self.profile_id,
+                "device_properties_count": len(merged_data.get("deviceProperties", [])),
+            }
+        except Exception as e:
+            return {"error": f"Failed to read merged trace: {str(e)}"}
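The core of the new module is the sort key: an event's sort_index becomes its pid plus each rank times its multiplier, so an event from DP1/TP2 with pid 42 sorts at 42 + 1*100_000_000 + 2*100 = 100_000_242, grouping processes by DP rank first and TP rank last in the trace viewer. A minimal usage sketch (the directory and profile id below are hypothetical; the class and methods are exactly those added above):

    from sglang.srt.utils.profile_merger import ProfileMerger

    # Hypothetical values; any directory holding per-rank
    # "<profile_id>*.trace.json.gz" files written by the profiler works.
    merger = ProfileMerger(output_dir="/tmp/sglang_profile", profile_id="1700000000")
    merged_path = merger.merge_chrome_traces()  # raises ValueError if nothing matches
    print(merged_path)                          # .../merged-1700000000.trace.json.gz
    print(merger.get_merge_summary())           # event totals and source-file names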
sglang/test/attention/test_flashattn_backend.py
@@ -66,7 +66,7 @@ class MockModelRunner:
             enable_memory_saver=False,
         )
         # Required by torch native backend
-        self.server_args = ServerArgs(model_path="
+        self.server_args = ServerArgs(model_path="dummy")
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
sglang/test/attention/test_flashattn_mla_backend.py
@@ -4,7 +4,6 @@ import torch
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
sglang/test/attention/test_prefix_chunk_info.py
@@ -2,8 +2,6 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
-from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.test.test_utils import CustomTestCase
sglang/test/attention/test_trtllm_mla_backend.py
@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
     TRTLLMMLABackend,
     TRTLLMMLADecodeMetadata,
 )
-from sglang.srt.layers.attention.utils import
+from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.utils import is_flashinfer_available
 from sglang.test.test_utils import CustomTestCase
 
@@ -104,15 +109,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Single FP16 vs reference",
         },
-        {
-            "name": "single_fp8",
-            "batch_size": 1,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Single FP8 vs reference",
-        },
+        # {
+        #     "name": "single_fp8",
+        #     "batch_size": 1,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Single FP8 vs reference",
+        # },
         {
             "name": "batch_fp16",
             "batch_size": 32,
@@ -120,15 +125,15 @@ TEST_CASES = {
             "page_size": 32,
             "description": "Batch FP16 vs reference",
         },
-        {
-            "name": "batch_fp8",
-            "batch_size": 32,
-            "max_seq_len": 64,
-            "page_size": 64,
-            "tolerance": 1e-1,
-            "kv_cache_dtype": torch.float8_e4m3fn,
-            "description": "Batch FP8 vs reference",
-        },
+        # {
+        #     "name": "batch_fp8",
+        #     "batch_size": 32,
+        #     "max_seq_len": 64,
+        #     "page_size": 64,
+        #     "tolerance": 1e-1,
+        #     "kv_cache_dtype": torch.float8_e4m3fn,
+        #     "description": "Batch FP8 vs reference",
+        # },
     ],
     "page_size_consistency": [
         # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
@@ -213,13 +218,7 @@ class MockModelRunner:
         self.page_size = config["page_size"]
 
         # Server args stub - needed by attention backends
-        self.server_args = type(
-            "ServerArgs",
-            (),
-            {
-                "enable_dp_attention": False,  # Default value for testing
-            },
-        )
+        self.server_args = get_global_server_args()
 
         # Model-config stub with MLA attributes
        self.model_config = type(
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
 class TestTRTLLMMLA(CustomTestCase):
     """Test suite for TRTLLM MLA backend with centralized configuration."""
 
+    @classmethod
+    def setUpClass(cls):
+        """Set up global server args for testing."""
+        server_args = ServerArgs(model_path="dummy")
+        server_args.enable_dp_attention = False
+        set_global_server_args_for_scheduler(server_args)
+
+    @classmethod
+    def tearDownClass(cls):
+        pass
+
     def _merge_config(self, test_case):
         """Merge test case with default configuration."""
         config = DEFAULT_CONFIG.copy()
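This setUpClass pairs with the MockModelRunner hunk above: the per-test type("ServerArgs", (), {...}) stub is gone, and the backend stub now reads process-global server args. A minimal sketch of the pattern, assuming the getter returns the object handed to the setter:

    from sglang.srt.server_args import (
        ServerArgs,
        get_global_server_args,
        set_global_server_args_for_scheduler,
    )

    args = ServerArgs(model_path="dummy")  # "dummy" as in setUpClass above
    args.enable_dp_attention = False
    set_global_server_args_for_scheduler(args)

    # Anything that later calls get_global_server_args(), such as the
    # MockModelRunner stub, sees this same configuration.
    assert get_global_server_args().enable_dp_attention is False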
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
         backend.init_forward_metadata(fb)
 
         # Verify metadata exists
-        self.assertIsNotNone(backend.
-        self.assertIsInstance(
+        self.assertIsNotNone(backend.forward_decode_metadata)
+        self.assertIsInstance(
+            backend.forward_decode_metadata, TRTLLMMLADecodeMetadata
+        )
 
         # Test metadata structure
-        metadata = backend.
-        self.assertIsNotNone(
-            metadata.workspace, "Workspace should be allocated"
-        )
+        metadata = backend.forward_decode_metadata
         self.assertIsNotNone(
             metadata.block_kv_indices, "Block KV indices should be created"
         )
 
-        # Test workspace properties
-        self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.uint8)
-        self.assertGreater(
-            metadata.workspace.numel(), 0, "Workspace should have non-zero size"
-        )
-
         # Test block KV indices properties
         self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
         self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Should satisfy TRT-LLM and Triton constraints
         trtllm_constraint = 128 // scenario["page_size"]
-
-
+        triton_constraint = get_num_page_per_block_flashmla(
+            scenario["page_size"]
         )
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
         self.assertEqual(
             calculated_blocks % constraint_lcm,
             0,
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Initialize metadata
         backend.init_forward_metadata(fb)
-        metadata = backend.
+        metadata = backend.forward_decode_metadata
 
         # Verify KV indices structure
         block_kv_indices = metadata.block_kv_indices
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Verify CUDA graph buffers are allocated
         self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.decode_cuda_graph_workspace)
 
         # Test capture metadata
         seq_lens = torch.full(
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
         capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
 
-        self.assertIsNotNone(capture_metadata.workspace)
         self.assertIsNotNone(capture_metadata.block_kv_indices)
 
         # Test replay with different sequence lengths
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )
 
         # Verify replay updated the metadata
-        replay_metadata = backend.
+        replay_metadata = backend.forward_decode_metadata
         self.assertIsNotNone(replay_metadata)
-        self.assertEqual(
-            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
-        )
 
     def test_metadata_consistency_across_calls(self):
         """Test metadata consistency across multiple forward calls."""
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_1, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_1)
-        metadata_1 = backend.
+        metadata_1 = backend.forward_decode_metadata
 
         # Second call with same sequence lengths
         seq_lens_2 = torch.tensor([32, 48], device=config["device"])
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_2, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_2)
-        metadata_2 = backend.
+        metadata_2 = backend.forward_decode_metadata
 
         # Metadata structure should be consistent
-        self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
         self.assertEqual(
             metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
         )
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
             config["batch_size"], seq_lens_3, backend, model_runner, config
         )
         backend.init_forward_metadata(fb_3)
-        metadata_3 = backend.
+        metadata_3 = backend.forward_decode_metadata
 
         # Should still have valid structure
-        self.assertIsNotNone(metadata_3.workspace)
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
 
@@ -1263,6 +1259,178 @@ class TestTRTLLMMLA(CustomTestCase):
             f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
         )
 
+    def test_draft_extend_padding_unpadding_kernels(self):
+        """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel."""
+
+        # Import the kernels
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            pad_draft_extend_query_kernel,
+            unpad_draft_extend_output_kernel,
+        )
+
+        def _create_test_data(
+            self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32
+        ):
+            """Create test data for kernel testing."""
+            device = torch.device("cuda")
+
+            # Create sequence lengths (varying lengths for each batch)
+            seq_lens = torch.randint(
+                1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32
+            )
+
+            # Create cumulative sequence lengths
+            cum_seq_lens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
+            cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)
+
+            # Create input query tensor (flattened format)
+            total_tokens = cum_seq_lens[-1].item()
+            q_input = torch.randn(
+                total_tokens, num_heads, head_dim, device=device, dtype=dtype
+            )
+
+            # Create padded query tensor (batch format)
+            padded_q = torch.zeros(
+                batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype
+            )
+
+            return q_input, padded_q, seq_lens, cum_seq_lens
+
+        def _create_test_output_data(
+            self,
+            batch_size,
+            token_per_batch,
+            tp_q_head_num,
+            v_head_dim,
+            dtype=torch.float32,
+        ):
+            """Create test data for unpad kernel testing."""
+            device = torch.device("cuda")
+
+            # Create accept lengths (varying lengths for each batch)
+            accept_lengths = torch.randint(
+                1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32
+            )
+
+            # Create cumulative accept lengths
+            cum_accept_lengths = torch.zeros(
+                batch_size + 1, device=device, dtype=torch.int32
+            )
+            cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0)
+
+            # Create raw output tensor (batch format)
+            raw_out = torch.randn(
+                batch_size,
+                token_per_batch,
+                tp_q_head_num,
+                v_head_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+            # Create output tensor (flattened format)
+            total_tokens = cum_accept_lengths[-1].item()
+            output = torch.empty(
+                total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype
+            )
+
+            return raw_out, output, accept_lengths, cum_accept_lengths
+
+        # Test 1: pad_draft_extend_query_kernel basic functionality
+        with self.subTest(test="pad_kernel_basic"):
+            batch_size = 4
+            max_seq_len = 8
+            num_heads = 16
+            head_dim = 64
+
+            q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data(
+                self, batch_size, max_seq_len, num_heads, head_dim
+            )
+
+            # Launch kernel
+            BLOCK_SIZE = 64
+            grid = (batch_size * max_seq_len,)
+
+            pad_draft_extend_query_kernel[grid](
+                q_ptr=q_input,
+                padded_q_ptr=padded_q,
+                seq_lens_q_ptr=seq_lens,
+                cumsum_ptr=cum_seq_lens,
+                batch_size=batch_size,
+                max_seq_len=max_seq_len,
+                num_heads=num_heads,
+                head_dim=head_dim,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+
+            # Verify the padding worked correctly
+            for i in range(batch_size):
+                seq_len = seq_lens[i].item()
+
+                # Check that valid positions are copied correctly
+                for pos in range(seq_len):
+                    input_start = cum_seq_lens[i].item()
+                    input_pos = input_start + pos
+
+                    # Compare input and output for valid positions
+                    input_data = q_input[input_pos]
+                    output_data = padded_q[i, pos]
+
+                    torch.testing.assert_close(
+                        input_data, output_data, rtol=1e-5, atol=1e-6
+                    )
+
+                # Check that invalid positions are zero
+                for pos in range(seq_len, max_seq_len):
+                    output_data = padded_q[i, pos]
+                    self.assertTrue(
+                        torch.allclose(output_data, torch.zeros_like(output_data)),
+                        f"Position {pos} in batch {i} should be zero",
+                    )
+
+        # Test 2: unpad_draft_extend_output_kernel basic functionality
+        with self.subTest(test="unpad_kernel_basic"):
+            batch_size = 4
+            token_per_batch = 8
+            tp_q_head_num = 16
+            v_head_dim = 64
+
+            raw_out, output, accept_lengths, cum_accept_lengths = (
+                _create_test_output_data(
+                    self, batch_size, token_per_batch, tp_q_head_num, v_head_dim
+                )
+            )
+
+            # Launch kernel
+            BLOCK_SIZE = 64
+            grid = (batch_size * token_per_batch,)
+
+            unpad_draft_extend_output_kernel[grid](
+                raw_out_ptr=raw_out,
+                output_ptr=output,
+                accept_length_ptr=accept_lengths,
+                cumsum_ptr=cum_accept_lengths,
+                batch_size=batch_size,
+                token_per_batch=token_per_batch,
+                tp_q_head_num=tp_q_head_num,
+                v_head_dim=v_head_dim,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+
+            # Verify the unpadding worked correctly
+            for i in range(batch_size):
+                accept_len = accept_lengths[i].item()
+                output_start = cum_accept_lengths[i].item()
+
+                # Check that valid positions are copied correctly
+                for pos in range(accept_len):
+                    input_data = raw_out[i, pos]
+                    output_data = output[output_start + pos]
+
+                    torch.testing.assert_close(
+                        input_data, output_data, rtol=1e-5, atol=1e-6
+                    )
+
 
 if __name__ == "__main__":
     unittest.main()
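For reference, the contract this new test checks can be stated as a pure-PyTorch sketch (helper names here are made up; this is not the Triton implementation): padding scatters a flattened ragged batch into a zero-filled [batch, max_len, heads, dim] tensor, and unpadding gathers the accepted prefix of each row back into a flattened tensor.

    import torch

    def pad_reference(q, seq_lens, cum_seq_lens, max_seq_len):
        # Scatter [total_tokens, H, D] into [B, max_seq_len, H, D],
        # zero-filling positions past each batch's length.
        bsz = seq_lens.numel()
        padded = q.new_zeros(bsz, max_seq_len, *q.shape[1:])
        for i in range(bsz):
            n = int(seq_lens[i])
            start = int(cum_seq_lens[i])
            padded[i, :n] = q[start : start + n]
        return padded

    def unpad_reference(raw_out, accept_lens, cum_accept_lens):
        # Gather [B, token_per_batch, H, D] back into [total_accepted, H, D].
        total = int(cum_accept_lens[-1])
        out = raw_out.new_empty(total, *raw_out.shape[2:])
        for i in range(accept_lens.numel()):
            n = int(accept_lens[i])
            start = int(cum_accept_lens[i])
            out[start : start + n] = raw_out[i, :n]
        return out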
sglang/test/few_shot_gsm8k_engine.py
@@ -1,16 +1,14 @@
 import argparse
 import ast
 import asyncio
-import json
 import re
 import time
+from typing import Optional
 
 import numpy as np
 
 import sglang as sgl
-from sglang.
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 INVALID = -9999999