sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/{ngram_utils.py → ngram_info.py}
CHANGED
@@ -7,6 +7,8 @@ from typing import Optional, Tuple
 import torch
 import triton

+from sglang.srt.server_args import get_global_server_args
+
 logger = logging.getLogger(__name__)

 from dataclasses import dataclass
@@ -16,10 +18,11 @@ import torch.nn.functional as F
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
-from sglang.srt.managers.schedule_batch import
-
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
     get_last_loc,
-    global_server_args_dict,
 )
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
@@ -74,7 +77,10 @@ class NgramVerifyInput(SpecInput):
         batch.input_ids = self.draft_token

         if page_size == 1:
-            batch.out_cache_loc =
+            batch.out_cache_loc = alloc_token_slots(
+                batch.tree_cache,
+                len(batch.input_ids),
+            )
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             # TODO(lsyin): add prefix lens cpu here to support page size > 1
@@ -87,7 +93,8 @@ class NgramVerifyInput(SpecInput):
                 batch.req_pool_indices,
                 prefix_lens,
             )
-            batch.out_cache_loc =
+            batch.out_cache_loc = alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 end_offset,
@@ -345,10 +352,8 @@ class NgramVerifyInput(SpecInput):
             uniform_samples_for_final_sampling=coins_for_final_sampling,
             target_probs=target_probs,
             draft_probs=draft_probs,
-            threshold_single=
-
-            ],
-            threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
+            threshold_single=get_global_server_args().speculative_accept_threshold_single,
+            threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
             deterministic=True,
         )

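Note: the hunks above migrate this file from the module-level `global_server_args_dict` to the `get_global_server_args()` accessor and route KV-slot allocation through helpers from `sglang.srt.mem_cache.common`. A minimal sketch of the new access pattern follows; the wrapper function is illustrative, not actual sglang code.

```python
# Illustrative sketch only; read_accept_thresholds() is a hypothetical helper.
from sglang.srt.server_args import get_global_server_args

def read_accept_thresholds():
    args = get_global_server_args()
    # Previously read via global_server_args_dict["speculative_accept_threshold_*"].
    return (
        args.speculative_accept_threshold_single,
        args.speculative_accept_threshold_acc,
    )
```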

sglang/srt/speculative/ngram_worker.py
CHANGED
@@ -6,11 +6,12 @@ import torch
 from sgl_kernel.speculative import reconstruct_indices_from_tree_mask

 from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
-from sglang.srt.speculative.
+from sglang.srt.speculative.ngram_info import NgramVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 logger = logging.getLogger(__name__)
@@ -207,18 +208,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)

-    def forward_batch_generation(self, batch: ScheduleBatch) ->
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
         num_accepted_tokens = 0

         if model_worker_batch.forward_mode.is_target_verify():
-
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch, is_verify=True
             )
             logits_output, can_run_cuda_graph = (
-
-
+                batch_result.logits_output,
+                batch_result.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -228,16 +229,16 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE

         else:
-
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
             logits_output, next_token_ids, can_run_cuda_graph = (
-
-
-
+                batch_result.logits_output,
+                batch_result.next_token_ids,
+                batch_result.can_run_cuda_graph,
             )

-        return
+        return GenerationBatchResult(
             logits_output=logits_output,
             next_token_ids=next_token_ids,
             num_accepted_tokens=num_accepted_tokens,
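Note: `NGRAMWorker.forward_batch_generation` now returns a `GenerationBatchResult` instead of a plain tuple. A hedged sketch of how a caller would consume it; the helper below is hypothetical, and the field names are the ones visible in this diff.

```python
# Hypothetical caller of the new result object.
def run_ngram_step(worker, batch):
    result = worker.forward_batch_generation(batch)
    return (
        result.logits_output,
        result.next_token_ids,
        result.num_accepted_tokens,
    )
```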

sglang/srt/speculative/spec_info.py
CHANGED
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import IntEnum, auto
+from functools import lru_cache
 from typing import List, Tuple

 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -27,6 +28,7 @@ class SpeculativeAlgorithm(IntEnum):
     def is_ngram(self):
         return self == SpeculativeAlgorithm.NGRAM

+    @lru_cache(maxsize=None)
     @staticmethod
     def from_string(name: str):
         name_map = {

sglang/srt/speculative/spec_utils.py
CHANGED
@@ -3,24 +3,33 @@ from __future__ import annotations
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, List

 import torch
 import triton
 import triton.language as tl
+from huggingface_hub import snapshot_download

 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+from sglang.srt.distributed.parallel_state import (
+    GroupCoordinator,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.environ import envs
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.utils import is_cuda, is_hip

+if TYPE_CHECKING:
+    from sglang.srt.speculative.eagle_info import EagleVerifyInput
+
+
 if is_cuda():
     from sgl_kernel import fast_topk
 elif is_hip():
     from sgl_kernel import fast_topk

-if TYPE_CHECKING:
-    from sglang.srt.speculative.eagle_info import EagleVerifyInput

 logger = logging.getLogger(__name__)

@@ -436,7 +445,7 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info


-def
+def generate_simulated_accept_index(
     accept_index,
     predict,
     accept_length,
@@ -604,3 +613,29 @@ def generate_token_bitmask(

     verify_input.grammar = grammar
     return allocate_token_bitmask
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path, weights_only=True)
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+def detect_nan(logits_output: LogitsProcessorOutput):
+    logits = logits_output.next_token_logits
+    if torch.any(torch.isnan(logits)):
+        logger.error("Detected errors during sampling! NaN in the logits.")
+        raise ValueError("Detected errors during sampling! NaN in the logits.")
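Note: `load_token_map` and `draft_tp_context` now live in `spec_utils` (the latter was previously defined in worker modules such as `standalone_worker.py`, per the hunks below). A sketch of how a speculative worker might use them; the function and its arguments are illustrative, not the actual EAGLE worker code.

```python
from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map

def init_draft_side(token_map_path, draft_tp_group):
    # Falls back to downloading the hosting repo when the local path is missing,
    # then returns the hot-token ids as an int64 tensor.
    hot_token_ids = load_token_map(token_map_path)
    # Draft-model work should run under the draft TP group, not the target's.
    with draft_tp_context(draft_tp_group):
        pass  # draft model initialization / forward would go here
    return hot_token_ids
```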

sglang/srt/speculative/standalone_worker.py
CHANGED
@@ -1,29 +1,20 @@
 import logging
-from contextlib import contextmanager
 from typing import Optional

 import torch

-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_worker import EAGLEWorker
+from sglang.srt.speculative.eagle_worker import EAGLEWorker
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
 from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda

 if is_cuda():
-    from sgl_kernel import segment_packbits
+    from sgl_kernel import segment_packbits  # noqa: F401

 logger = logging.getLogger(__name__)
-
-
-
-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-    # Draft model doesn't use dp and has its own tp group.
-    # We disable mscclpp now because it doesn't support 2 comm groups.
-    with patch_tensor_parallel_group(tp_group):
-        yield
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


 class StandaloneWorker(EAGLEWorker):
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
-        self.padded_static_len = -1

         # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len

sglang/srt/tokenizer/tiktoken_tokenizer.py
CHANGED
@@ -133,9 +133,9 @@ class TiktokenTokenizer:
         )
         return self.encode(ret) if tokenize else ret

-    def __call__(self, text, **kwargs):
+    def __call__(self, text: List[str], **kwargs):
         return {
-            "input_ids": self.encode(text
+            "input_ids": [self.encode(x) for x in text],
         }

     def init_xgrammar(self):
sglang/srt/two_batch_overlap.py
CHANGED
@@ -4,10 +4,11 @@ import copy
 import dataclasses
 import logging
 from dataclasses import replace
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

 import torch

+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
@@ -20,9 +21,11 @@ from sglang.srt.layers.moe import (
     get_tbo_token_distribution_threshold,
     is_tbo_enabled,
 )
-from sglang.srt.layers.moe.token_dispatcher import
-
-
+from sglang.srt.layers.moe.token_dispatcher import (
+    DeepEPDispatcher,
+    MooncakeEPDispatcher,
+)
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -30,12 +33,13 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
-from sglang.srt.
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
+    from sglang.srt.speculative.eagle_info import EagleVerifyInput

 _is_hip = is_hip()

@@ -153,7 +157,7 @@ def _update_device_and_sum_field_from_cpu_field(
         cpu_value
         if isinstance(cpu_value, torch.Tensor)
         else torch.tensor(cpu_value, dtype=old_device_value.dtype)
-    ).to(device=
+    ).to(device=get_global_server_args().device, non_blocking=True)
     setattr(batch, device_field, new_device_value)

     if sum_field is not None:
@@ -362,7 +366,7 @@ class TboDPAttentionPreparer:
         ):

            deepep_mode = get_deepep_mode()
-
+           enable_a2a_moe = not get_moe_a2a_backend().is_none()
            enable_two_batch_overlap = is_tbo_enabled()

            self.enable_two_batch_overlap = enable_two_batch_overlap
@@ -391,7 +395,7 @@ class TboDPAttentionPreparer:
                 local_batch.forward_mode.is_extend()
                 and not local_batch.forward_mode.is_target_verify()
             )
-            and
+            and enable_a2a_moe
             and (resolved_deepep_mode.is_low_latency())
         )
         else:
@@ -582,7 +586,7 @@ class TboForwardBatchPreparer:
             sum_field=None,
         )
         _, child_b.extend_start_loc = compute_position(
-
+            get_global_server_args().attention_backend,
             child_b.extend_prefix_lens,
             child_b.extend_seq_lens,
             child_b.extend_num_tokens,
@@ -667,6 +671,7 @@ class TboForwardBatchPreparer:
             "can_run_dp_cuda_graph",
             "dp_padding_mode",
             "global_forward_mode",
+            "is_prefill_only",
             "spec_algorithm",
             "capture_hidden_mode",
             "padded_static_len",
@@ -686,7 +691,7 @@ class TboForwardBatchPreparer:

         # TODO improve, e.g. unify w/ `init_raw`
         if (
-
+            get_global_server_args().moe_dense_tp_size == 1
             and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
@@ -754,7 +759,7 @@ class TboForwardBatchPreparer:
         value_a = min(tbo_split_token_index, num_token_non_padded)
         value_b = max(0, num_token_non_padded - tbo_split_token_index)
         return torch.tensor([value_a, value_b], dtype=torch.int32).to(
-            device=
+            device=get_global_server_args().device, non_blocking=True
         )

     @classmethod
@@ -966,9 +971,14 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
 class MaybeTboDeepEPDispatcher:
     def __init__(self, **kwargs):
         num_inner_dispatchers = 2 if is_tbo_enabled() else 1
-
-
-
+        if get_moe_a2a_backend().is_deepep():
+            self._inners = [
+                DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+            ]
+        elif get_moe_a2a_backend().is_mooncake():
+            self._inners = [
+                MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+            ]

     def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
         return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
@@ -990,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:

     def combine_b(self, **kwargs):
         return self._execute("combine_b", **kwargs)
+
+    def set_quant_config(self, quant_config: dict):
+        for inner in self._inners:
+            inner.set_quant_config(quant_config)
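Note: `MaybeTboDeepEPDispatcher` now picks `DeepEPDispatcher` or `MooncakeEPDispatcher` based on the MoE all-to-all backend and fans the new `set_quant_config` call out to every inner dispatcher. A simplified, hypothetical sketch of that pattern (class and parameter names below are illustrative, not sglang's):

```python
class DispatcherWrapperSketch:
    def __init__(self, dispatcher_cls, num_inner, **kwargs):
        # One inner dispatcher per TBO sub-batch (2 with TBO enabled, else 1).
        self._inners = [dispatcher_cls(**kwargs) for _ in range(num_inner)]

    def set_quant_config(self, quant_config: dict):
        # New in 0.5.4: propagate the quantization config to every inner dispatcher.
        for inner in self._inners:
            inner.set_quant_config(quant_config)
```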
sglang/srt/utils/__init__.py
CHANGED
@@ -1,2 +1,2 @@
 # Temporarily do this to avoid changing all imports in the repo
-from .common import *
+from sglang.srt.utils.common import *

sglang/srt/{bench_utils.py → utils/bench_utils.py}
CHANGED
@@ -1,4 +1,5 @@
 import os
+import re
 import sys
 from contextlib import nullcontext

@@ -108,7 +109,8 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([name
+                sum([int(re.search(name, line) is not None) for line in prof_lines])
+                == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():