sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration argument parser for command-line applications.
|
|
3
|
+
Handles merging of YAML configuration files with command-line arguments.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, List
|
|
9
|
+
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ConfigArgumentMerger:
|
|
16
|
+
"""Handles merging of configuration file arguments with command-line arguments."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, boolean_actions: List[str] = None):
|
|
19
|
+
"""Initialize with list of boolean action destinations."""
|
|
20
|
+
self.boolean_actions = boolean_actions or []
|
|
21
|
+
|
|
22
|
+
def merge_config_with_args(self, cli_args: List[str]) -> List[str]:
|
|
23
|
+
"""
|
|
24
|
+
Merge configuration file arguments with command-line arguments.
|
|
25
|
+
|
|
26
|
+
Configuration arguments are inserted after the subcommand to maintain
|
|
27
|
+
proper precedence: CLI > Config > Defaults
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
cli_args: List of command-line arguments
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Merged argument list with config values inserted
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
ValueError: If multiple config files specified or no config file provided
|
|
37
|
+
"""
|
|
38
|
+
config_file_path = self._extract_config_file_path(cli_args)
|
|
39
|
+
if not config_file_path:
|
|
40
|
+
return cli_args
|
|
41
|
+
|
|
42
|
+
config_args = self._parse_yaml_config(config_file_path)
|
|
43
|
+
return self._insert_config_args(cli_args, config_args, config_file_path)
|
|
44
|
+
|
|
45
|
+
def _extract_config_file_path(self, args: List[str]) -> str:
|
|
46
|
+
"""Extract the config file path from arguments."""
|
|
47
|
+
config_indices = [i for i, arg in enumerate(args) if arg == "--config"]
|
|
48
|
+
|
|
49
|
+
if len(config_indices) > 1:
|
|
50
|
+
raise ValueError("Multiple config files specified! Only one allowed.")
|
|
51
|
+
|
|
52
|
+
if not config_indices:
|
|
53
|
+
return None
|
|
54
|
+
|
|
55
|
+
config_index = config_indices[0]
|
|
56
|
+
if config_index == len(args) - 1:
|
|
57
|
+
raise ValueError("No config file specified after --config flag!")
|
|
58
|
+
|
|
59
|
+
return args[config_index + 1]
|
|
60
|
+
|
|
61
|
+
def _insert_config_args(
|
|
62
|
+
self, cli_args: List[str], config_args: List[str], config_file_path: str
|
|
63
|
+
) -> List[str]:
|
|
64
|
+
"""Insert configuration arguments into the CLI argument list."""
|
|
65
|
+
config_index = cli_args.index("--config")
|
|
66
|
+
|
|
67
|
+
# Split arguments around config file
|
|
68
|
+
before_config = cli_args[:config_index]
|
|
69
|
+
after_config = cli_args[config_index + 2 :] # Skip --config and file path
|
|
70
|
+
|
|
71
|
+
# Simple merge: config args + CLI args
|
|
72
|
+
return config_args + before_config + after_config
|
|
73
|
+
|
|
74
|
+
def _parse_yaml_config(self, file_path: str) -> List[str]:
|
|
75
|
+
"""
|
|
76
|
+
Parse YAML configuration file and convert to argument list.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
file_path: Path to the YAML configuration file
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
List of arguments in format ['--key', 'value', ...]
|
|
83
|
+
|
|
84
|
+
Raises:
|
|
85
|
+
ValueError: If file is not YAML or cannot be read
|
|
86
|
+
"""
|
|
87
|
+
self._validate_yaml_file(file_path)
|
|
88
|
+
|
|
89
|
+
try:
|
|
90
|
+
with open(file_path, "r") as file:
|
|
91
|
+
config_data = yaml.safe_load(file)
|
|
92
|
+
except Exception as e:
|
|
93
|
+
logger.error(f"Failed to read config file {file_path}: {e}")
|
|
94
|
+
raise
|
|
95
|
+
|
|
96
|
+
# Handle empty files or None content
|
|
97
|
+
if config_data is None:
|
|
98
|
+
config_data = {}
|
|
99
|
+
|
|
100
|
+
if not isinstance(config_data, dict):
|
|
101
|
+
raise ValueError("Config file must contain a dictionary at root level")
|
|
102
|
+
|
|
103
|
+
return self._convert_config_to_args(config_data)
|
|
104
|
+
|
|
105
|
+
def _validate_yaml_file(self, file_path: str) -> None:
|
|
106
|
+
"""Validate that the file is a YAML file."""
|
|
107
|
+
path = Path(file_path)
|
|
108
|
+
if path.suffix.lower() not in [".yaml", ".yml"]:
|
|
109
|
+
raise ValueError(f"Config file must be YAML format, got: {path.suffix}")
|
|
110
|
+
|
|
111
|
+
if not path.exists():
|
|
112
|
+
raise ValueError(f"Config file not found: {file_path}")
|
|
113
|
+
|
|
114
|
+
def _convert_config_to_args(self, config: Dict[str, Any]) -> List[str]:
|
|
115
|
+
"""Convert configuration dictionary to argument list."""
|
|
116
|
+
args = []
|
|
117
|
+
|
|
118
|
+
for key, value in config.items():
|
|
119
|
+
if isinstance(value, bool):
|
|
120
|
+
self._add_boolean_arg(args, key, value)
|
|
121
|
+
elif isinstance(value, list):
|
|
122
|
+
self._add_list_arg(args, key, value)
|
|
123
|
+
else:
|
|
124
|
+
self._add_scalar_arg(args, key, value)
|
|
125
|
+
|
|
126
|
+
return args
|
|
127
|
+
|
|
128
|
+
def _add_boolean_arg(self, args: List[str], key: str, value: bool) -> None:
|
|
129
|
+
"""Add boolean argument to the list."""
|
|
130
|
+
if key in self.boolean_actions:
|
|
131
|
+
# For boolean actions, always add the flag and value
|
|
132
|
+
args.extend([f"--{key}", str(value).lower()])
|
|
133
|
+
else:
|
|
134
|
+
# For regular booleans, only add flag if True
|
|
135
|
+
if value:
|
|
136
|
+
args.append(f"--{key}")
|
|
137
|
+
|
|
138
|
+
def _add_list_arg(self, args: List[str], key: str, value: List[Any]) -> None:
|
|
139
|
+
"""Add list argument to the list."""
|
|
140
|
+
if value: # Only add if list is not empty
|
|
141
|
+
args.append(f"--{key}")
|
|
142
|
+
args.extend(str(item) for item in value)
|
|
143
|
+
|
|
144
|
+
def _add_scalar_arg(self, args: List[str], key: str, value: Any) -> None:
|
|
145
|
+
"""Add scalar argument to the list."""
|
|
146
|
+
args.extend([f"--{key}", str(value)])
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# Copyright 2025 SGLang Team
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ==============================================================================
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from sglang.srt.layers import deep_gemm_wrapper
|
|
23
|
+
from sglang.srt.layers.moe import get_moe_runner_backend
|
|
24
|
+
from sglang.srt.layers.moe.topk import TopKOutput
|
|
25
|
+
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
|
26
|
+
from sglang.srt.utils import get_int_env_var
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SboFlags:
|
|
33
|
+
# TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...
|
|
34
|
+
|
|
35
|
+
@classmethod
|
|
36
|
+
def enable_combine_down_gemm_two_stream_overlap(cls):
|
|
37
|
+
return (
|
|
38
|
+
is_sbo_enabled()
|
|
39
|
+
# currently only cutedsl backend supports it
|
|
40
|
+
and get_moe_runner_backend().is_flashinfer_cutedsl()
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
@classmethod
|
|
44
|
+
def enable_combine_shared_two_stream_overlap(cls):
|
|
45
|
+
return is_sbo_enabled()
|
|
46
|
+
|
|
47
|
+
@classmethod
|
|
48
|
+
def fuse_shared_experts_inside_sbo(cls):
|
|
49
|
+
# TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
|
|
50
|
+
return cls.enable_combine_shared_two_stream_overlap()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class CombineOverlapArgs:
|
|
55
|
+
# this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
|
|
56
|
+
overlap: bool
|
|
57
|
+
stream: torch.cuda.Stream
|
|
58
|
+
wait_event: torch.cuda.Event
|
|
59
|
+
num_sms: int
|
|
60
|
+
signal: Optional[torch.Tensor] = None
|
|
61
|
+
threshold: int = 0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
|
|
65
|
+
class DownGemmOverlapArgs:
|
|
66
|
+
num_sms: int
|
|
67
|
+
signal: torch.Tensor
|
|
68
|
+
start_event: torch.cuda.Event
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def execute_sbo(
|
|
72
|
+
forward_shared_experts: Callable[[], Any],
|
|
73
|
+
experts: FusedMoE,
|
|
74
|
+
hidden_states: torch.Tensor,
|
|
75
|
+
topk_output: TopKOutput,
|
|
76
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
|
77
|
+
disable_sbo: bool = False,
|
|
78
|
+
):
|
|
79
|
+
|
|
80
|
+
dispatch_output = experts.dispatcher.dispatch(
|
|
81
|
+
hidden_states=hidden_states, topk_output=topk_output
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
|
85
|
+
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
hidden_states = experts.run_moe_core(
|
|
89
|
+
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
|
90
|
+
)
|
|
91
|
+
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
|
92
|
+
e.record()
|
|
93
|
+
|
|
94
|
+
if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
|
|
95
|
+
# TODO reduce sm for non-deepgemm
|
|
96
|
+
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
|
97
|
+
meta_overlap_args["compute_num_sms"]
|
|
98
|
+
):
|
|
99
|
+
forward_shared_experts()
|
|
100
|
+
|
|
101
|
+
hidden_states = experts.dispatcher.combine(
|
|
102
|
+
hidden_states=hidden_states,
|
|
103
|
+
topk_ids=dispatch_output.topk_ids,
|
|
104
|
+
topk_weights=dispatch_output.topk_weights,
|
|
105
|
+
overlap_args=combine_overlap_args,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
return hidden_states
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
|
|
112
|
+
if disable_sbo or not (
|
|
113
|
+
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
|
114
|
+
or SboFlags.enable_combine_shared_two_stream_overlap()
|
|
115
|
+
):
|
|
116
|
+
return None, None, {}
|
|
117
|
+
|
|
118
|
+
hidden_states = dispatch_output.hidden_states
|
|
119
|
+
|
|
120
|
+
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
|
121
|
+
|
|
122
|
+
total_num_sms = torch.cuda.get_device_properties(
|
|
123
|
+
device="cuda"
|
|
124
|
+
).multi_processor_count
|
|
125
|
+
communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
|
|
126
|
+
compute_num_sms = total_num_sms - communicate_num_sms
|
|
127
|
+
|
|
128
|
+
assert alt_stream is not None
|
|
129
|
+
combine_wait_event = torch.cuda.Event()
|
|
130
|
+
combine_overlap_args = CombineOverlapArgs(
|
|
131
|
+
overlap=False,
|
|
132
|
+
num_sms=communicate_num_sms,
|
|
133
|
+
stream=alt_stream,
|
|
134
|
+
wait_event=combine_wait_event,
|
|
135
|
+
)
|
|
136
|
+
meta_overlap_args = dict(
|
|
137
|
+
compute_num_sms=compute_num_sms,
|
|
138
|
+
)
|
|
139
|
+
down_gemm_overlap_args = None
|
|
140
|
+
|
|
141
|
+
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
|
|
142
|
+
# TODO use zero_allocator to remove this `torch.zeros` call
|
|
143
|
+
# NOTE ours v2 use uint32 not int32 currently
|
|
144
|
+
combine_signal = torch.zeros(
|
|
145
|
+
num_local_experts, dtype=torch.uint32, device=hidden_states.device
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
down_gemm_overlap_args = DownGemmOverlapArgs(
|
|
149
|
+
signal=combine_signal,
|
|
150
|
+
start_event=combine_wait_event,
|
|
151
|
+
num_sms=compute_num_sms,
|
|
152
|
+
)
|
|
153
|
+
combine_overlap_args.overlap = True
|
|
154
|
+
combine_overlap_args.signal = combine_signal
|
|
155
|
+
combine_overlap_args.threshold = compute_num_sms
|
|
156
|
+
else:
|
|
157
|
+
meta_overlap_args |= dict(
|
|
158
|
+
record_event_after_down=combine_wait_event,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseDraftWorker(ABC):
|
|
11
|
+
@abstractmethod
|
|
12
|
+
def draft():
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def draft_extend():
|
|
17
|
+
pass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseSpecWorker(ABC):
|
|
21
|
+
@property
|
|
22
|
+
@abstractmethod
|
|
23
|
+
def target_worker(self) -> TpModelWorker:
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def draft_worker(self) -> BaseDraftWorker:
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def clear_cache_pool(self):
|
|
33
|
+
# TODO: move this abstract method to BaseTpWorker and call through self.model_runner
|
|
34
|
+
pass
|