sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py

@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     LinearMethodBase,
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.with_bias = with_bias
 
         # Fused gate_up_proj (column parallel)
-        w13_weight_n, w13_weight_k = 2 *
+        w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
         if self.use_triton_kernels:
             w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         if self.with_bias:
             w13_weight_bias = torch.nn.Parameter(
-                torch.empty(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
-
+            intermediate_size_per_partition,
         )
         if self.use_triton_kernels:
             w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
         return self.forward(
-            x=x,
             layer=layer,
-
-            moe_runner_config=moe_runner_config,
+            dispatch_output=dispatch_output,
         )
 
     def forward_cuda(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
 
         if self.use_triton_kernels:
             if self.with_bias:
                 assert self.triton_kernel_moe_with_bias_forward is not None
-
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
             else:
                 assert self.triton_kernel_moe_forward is not None
-
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                    topk_weights = torch.ones_like(
                        topk_weights, dtype=torch.float32
                    )  # topk_weights must be FP32 (float32)
-
+                output = fused_moe(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                        else ActivationType.Gelu
                    ),
                )
+                return StandardCombineInput(hidden_states=output)
             else:
-                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-                    fused_experts,
-                )
 
-
-
-
-
-                    b1=getattr(layer, "w13_weight_bias", None),
+                quant_info = TritonMoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    b13=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
                 )
+                return self.runner.run(dispatch_output, quant_info)
 
     def forward_cpu(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
        assert (
            moe_runner_config.activation == "silu"
        ), f"activation = {moe_runner_config.activation} is not supported."
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            x, topk_weights = apply_topk_weights_cpu(
                moe_runner_config.apply_router_weight_on_input, topk_weights, x
            )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
                layer.w2_weight,
@@ -348,33 +366,103 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                None,  # a2_scale
                True,  # is_vnni
            )
+            return StandardCombineInput(hidden_states=output)
        else:
            from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 
-
+            output = moe_forward_native(
                layer,
                x,
                topk_output,
                moe_runner_config,
            )
+            return StandardCombineInput(hidden_states=output)
 
     def forward_npu(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_weights, topk_ids, _ = dispatch_output.topk_output
+
+        original_dtype = x.dtype
+        num_tokens = x.shape[0]
+        topk_weights = topk_weights.to(x.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        num_experts = layer.num_experts
+        top_k = layer.top_k
+        row_idx_len = num_tokens * top_k
+        row_idx = (
+            torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
+            .view(top_k, -1)
+            .permute(1, 0)
+            .contiguous()
+        )
 
-
-
-
-
-
+        hidden_states, expanded_row_idx, expanded_expert_idx = (
+            torch_npu.npu_moe_init_routing(
+                x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
+            )
+        )
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts
        )
 
-
+        expert_tokens = expert_tokens.to(torch.int64)
+        if layer.w13_weight.shape[-1] == layer.hidden_size:
+            w13 = layer.w13_weight.transpose(1, 2)
+            w2 = layer.w2_weight.transpose(1, 2)
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w13],
+            split_item=2,
+            group_list_type=0,
+            group_type=0,
+            group_list=expert_tokens,
+            output_dtype=original_dtype,
+        )[0]
+
+        # act_fn:
+        if self.moe_runner_config.activation == "silu":
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+        else:
+            from sglang.srt.layers.activation import GeluAndMul
+
+            hidden_states = GeluAndMul()(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w2],
+            split_item=2,
+            group_list_type=0,
+            group_type=0,
+            group_list=expert_tokens,
+            output_dtype=original_dtype,
+        )[0]
+
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+
+        return StandardCombineInput(hidden_states=final_hidden_states)
+
+    def forward_tpu(self, *args, **kwargs) -> CombineInput:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
     forward_native = forward_cpu
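The UnquantizedFusedMoEMethod hunks above capture the interface change that runs through the 0.5.3 MoE quantization methods: the runner configuration is set once in create_moe_runner(), and apply()/forward_*() now take a dispatch output and return a combine input instead of receiving raw tensors and a config on every call. Below is a minimal, self-contained sketch of that calling convention. The two dataclasses are simplified stand-ins for sglang.srt.layers.moe.token_dispatcher.StandardDispatchOutput / StandardCombineInput, and ToyMoEMethod with its placeholder expert math is hypothetical, not sglang code.

```python
from dataclasses import dataclass

import torch


@dataclass
class StandardDispatchOutput:  # simplified stand-in for the sglang class
    hidden_states: torch.Tensor  # [num_tokens, hidden_size]
    topk_output: tuple           # (topk_weights, topk_ids, ...)


@dataclass
class StandardCombineInput:  # simplified stand-in for the sglang class
    hidden_states: torch.Tensor


class ToyMoEMethod:  # hypothetical; only illustrates the 0.5.3 interface shape
    def create_moe_runner(self, layer, moe_runner_config):
        # The config is stored once here instead of being threaded through
        # every apply() call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids = dispatch_output.topk_output[:2]
        # Placeholder expert computation; the real methods call fused MoE kernels.
        out = x * topk_weights.sum(dim=-1, keepdim=True).to(x.dtype)
        return StandardCombineInput(hidden_states=out)


if __name__ == "__main__":
    tokens, hidden, top_k = 4, 8, 2
    dispatch = StandardDispatchOutput(
        hidden_states=torch.randn(tokens, hidden),
        topk_output=(
            torch.rand(tokens, top_k),
            torch.zeros(tokens, top_k, dtype=torch.int32),
        ),
    )
    method = ToyMoEMethod()
    method.create_moe_runner(layer=None, moe_runner_config=None)
    combine = method.apply(layer=None, dispatch_output=dispatch)
    print(combine.hidden_states.shape)  # torch.Size([4, 8])
```

Keeping the apply() signature identical across backends appears to be what lets the token dispatcher hand its output to any quantization method interchangeably, since only the dispatch/combine containers cross the boundary.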
sglang/srt/layers/quantization/w4afp8.py

@@ -17,12 +17,15 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -133,7 +136,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -145,7 +148,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),
@@ -159,7 +162,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,
@@ -173,7 +176,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 *
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
@@ -186,7 +189,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -220,13 +223,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 *
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
@@ -282,16 +285,22 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        # TODO(ch-wan): move it out of this class
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -26,8 +28,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=fp8_dtype,
             ),
             requires_grad=False,
         )
@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=fp8_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-
-
-            layer.
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
             per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)