sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,441 @@
|
|
1
|
+
"""
|
2
|
+
MMMU evaluation for VLMs using the run_eval simple-evals interface.
|
3
|
+
|
4
|
+
"""
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
import base64
|
9
|
+
import io
|
10
|
+
from typing import List, Optional, Tuple
|
11
|
+
|
12
|
+
from datasets import concatenate_datasets, load_dataset
|
13
|
+
from PIL import Image
|
14
|
+
|
15
|
+
from sglang.test import simple_eval_common as common
|
16
|
+
from sglang.test.simple_eval_common import (
|
17
|
+
HTML_JINJA,
|
18
|
+
Eval,
|
19
|
+
EvalResult,
|
20
|
+
SamplerBase,
|
21
|
+
SingleEvalResult,
|
22
|
+
map_with_progress,
|
23
|
+
)
|
24
|
+
|
25
|
+
|
26
|
+
class MMMUVLMEval(Eval):
    """MMMU (validation split) evaluation for vision-language models.

    Rows from every subject in ``DOMAIN_CAT2SUB_CAT`` are merged, ordered
    deterministically by ``id`` and turned into single-image chat prompts
    (only ``image_1`` is sent). Rows whose ``options`` field yields a
    non-empty list are scored as multiple-choice; all others as open-ended.
    """

    # Domain -> MMMU subject config names. Drives both which subjects are
    # loaded and how per-subject accuracy is aggregated in ``__call__``.
    DOMAIN_CAT2SUB_CAT = {
        "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
        "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
        "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
        "Health and Medicine": [
            "Basic_Medical_Science",
            "Clinical_Medicine",
            "Diagnostics_and_Laboratory_Medicine",
            "Pharmacy",
            "Public_Health",
        ],
        "Humanities and Social Science": [
            "History",
            "Literature",
            "Sociology",
            "Psychology",
        ],
        "Tech and Engineering": [
            "Agriculture",
            "Architecture_and_Engineering",
            "Computer_Science",
            "Electronics",
            "Energy_and_Power",
            "Materials",
            "Mechanical_Engineering",
        ],
    }

    # Image modes the PNG encoder can write directly; any other mode (e.g. a
    # CMYK JPEG) must be converted before encoding or ``save`` raises.
    _PNG_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

    def __init__(
        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
    ):
        """Create the MMMU VLM eval over all subjects (validation split).

        Args:
            num_examples: Number of samples to evaluate (``None`` keeps every
                row that has an image). Samples are selected deterministically
                by sorting on the dataset ``id``.
            num_threads: Worker threads passed to ``map_with_progress``.
            seed: Stored for interface compatibility; the selection is
                deterministic and does not currently use it.
        """
        self.num_examples = num_examples
        self.num_threads = num_threads
        self.seed = seed
        self.samples = self._prepare_mmmu_samples(self.num_examples)

    @staticmethod
    def _to_data_uri(image: Image.Image) -> str:
        """Return *image* encoded as a base64 PNG ``data:`` URI."""
        # RGBA is flattened to RGB (pre-existing behaviour); modes PNG cannot
        # store are converted too instead of aborting sample preparation.
        if image.mode == "RGBA" or image.mode not in MMMUVLMEval._PNG_MODES:
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"

    @staticmethod
    def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
        """Assign consecutive letters to *options*.

        Returns:
            ``({"A": options[0], "B": options[1], ...}, ["A", "B", ...])``.
        """
        all_choices = [chr(ord("A") + i) for i in range(len(options))]
        index2ans = dict(zip(all_choices, options))
        return index2ans, all_choices

    @staticmethod
    def _parse_options(raw_options) -> Optional[list]:
        """Decode a row's ``options`` field into a list, or None if unusable.

        String values hold the repr of a Python list, so they are decoded with
        ``ast.literal_eval``. Never use ``eval`` here: dataset content is
        external input.
        """
        import ast

        if not raw_options:
            return None
        if isinstance(raw_options, list):
            return raw_options
        try:
            return list(ast.literal_eval(raw_options))
        except Exception:
            # Malformed options: treat the question as open-ended (best effort).
            return None

    def _prepare_mmmu_samples(self, k: Optional[int]) -> List[dict]:
        """Load MMMU validation rows and build up to *k* evaluation samples.

        Subjects that fail to load are skipped with a warning. Rows are ordered
        by ``id`` (falling back to ``"<subject>:<row index>"``). Rows without a
        PIL ``image_1`` are skipped and later rows are used instead, so up to
        *k* samples are returned (``k=None`` keeps all rows with an image).

        Raises:
            RuntimeError: if no subject could be loaded.
        """
        import logging

        logger = logging.getLogger(__name__)

        subjects = [
            subj for subs in self.DOMAIN_CAT2SUB_CAT.values() for subj in subs
        ]

        # Load the validation split of each subject (best effort).
        datasets = []
        failed = []
        for subj in subjects:
            try:
                d = load_dataset("MMMU/MMMU", subj, split="validation")
                # Tag every row with its subject for per-subject reporting.
                d = d.add_column("__subject__", [subj] * len(d))
            except Exception as e:
                logger.warning("Skipping MMMU subject %s: %s", subj, e)
                failed.append(subj)
                continue
            datasets.append(d)
        if not datasets:
            raise RuntimeError(
                "Failed to load MMMU datasets (failed subjects: "
                + ", ".join(failed)
                + ")"
            )

        merged = concatenate_datasets(datasets)

        # Deterministic order: sort by id, falling back to "subject:index".
        # Read only the key columns; indexing merged[idx] would decode every
        # image column of every row just to compute a sort key.
        subject_col = list(merged["__subject__"])
        id_col = list(merged["id"]) if "id" in merged.column_names else None

        def _key(idx: int) -> str:
            if id_col is not None:
                return str(id_col[idx])
            return f"{subject_col[idx]}:{idx}"

        order = sorted(range(len(merged)), key=_key)

        samples: List[dict] = []
        for idx in order:
            # Stop once k usable samples are collected (k=None keeps all).
            if k is not None and len(samples) >= k:
                break
            ex = merged[idx]
            image = ex.get("image_1")
            # Only rows with a decoded PIL image can be sent to the model.
            if image is None or not hasattr(image, "convert"):
                continue
            subject = ex["__subject__"]

            options = self._parse_options(ex.get("options"))
            if options:
                index2ans, all_choices = self._build_mc_mapping(options)
                question_type = "multiple-choice"
            else:
                index2ans, all_choices = None, None
                question_type = "open"

            # Final textual prompt; list lettered choices for MC questions.
            prompt_text = f"Question: {ex.get('question', '')}\n\n"
            if options:
                prompt_text += "".join(
                    f"{letter}) {opt}\n" for letter, opt in zip(all_choices, options)
                )
            prompt_text += "\nAnswer: "

            samples.append(
                {
                    "id": ex.get("id", f"{subject}:{idx}"),
                    "final_input_prompt": prompt_text,
                    "image_data": self._to_data_uri(image),
                    "answer": ex.get("answer"),
                    "question_type": question_type,
                    "index2ans": index2ans,
                    "all_choices": all_choices,
                    "category": subject,
                }
            )

        return samples

    @staticmethod
    def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
        """Split *prompt* around its first inline image tag (e.g. ``<image 1>``).

        Returns ``(prefix, suffix)`` with the tag removed. If no tag is present,
        the whole prompt is the prefix and the suffix is empty.
        """
        import re

        # Match the real MMMU tag rather than any '<' ... '>' pair: questions
        # often contain inequalities such as "x < 3" or "a > b".
        match = re.search(r"<image\s*\d+>", prompt)
        if match is None:
            return prompt, ""
        return prompt[: match.start()], prompt[match.end() :]

    @staticmethod
    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
        """Build a single user message for an OpenAI-compatible chat API.

        The image part (``image_data`` used as the ``image_url`` URL) is placed
        where the first ``<image N>`` tag occurred in *prompt*, or after the
        whole prompt if there is no tag. Empty text segments are omitted.
        """
        prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)

        content: List[dict] = []
        if prefix:
            content.append({"type": "text", "text": prefix})
        content.append({"type": "image_url", "image_url": {"url": image_data}})
        if suffix:
            content.append({"type": "text", "text": suffix})
        return [{"role": "user", "content": content}]

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Query *sampler* on every sample and report subject/domain accuracy."""

        def fn(sample: dict):
            prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
                sample["final_input_prompt"], sample["image_data"]
            )

            response_text = sampler(prompt_messages)

            # Parse and score.
            gold = sample["answer"]
            if (
                sample["question_type"] == "multiple-choice"
                and sample["all_choices"]
                and sample["index2ans"]
            ):
                pred = _parse_multi_choice_response(
                    response_text, sample["all_choices"], sample["index2ans"]
                )
                score = 1.0 if (gold is not None and pred == gold) else 0.0
                extracted_answer = pred
            else:
                parsed_list = _parse_open_response(response_text)
                score = (
                    1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
                )
                extracted_answer = ", ".join(map(str, parsed_list))

            html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=gold,
                extracted_answer=extracted_answer,
            )

            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html_rendered,
                score=score,
                metrics={"__category__": sample["category"]},
                convo=convo,
            )

        results = map_with_progress(fn, self.samples, self.num_threads)

        # Per-subject tallies; the subject is carried in metrics["__category__"].
        per_cat_total: dict[str, int] = {}
        per_cat_correct: dict[str, int] = {}
        htmls = []
        convos = []
        for r in results:
            cat = r.metrics.get("__category__") if r.metrics else None
            if cat is None:
                cat = "Unknown"
            per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
            if r.score:
                per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
            htmls.append(r.html)
            convos.append(r.convo)

        # Accuracies are computed from raw counts; only displayed values are
        # rounded, so aggregates do not accumulate per-subject rounding error.
        printable_results = {}
        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
            present = [c for c in cats if c in per_cat_total]
            if not present:
                continue
            num = sum(per_cat_total[c] for c in present)
            correct = sum(per_cat_correct.get(c, 0) for c in present)
            printable_results[f"Overall-{domain}"] = {
                "num": num,
                "acc": round(correct / num, 3),
            }
            # One row per subject of this domain that has results.
            for c in present:
                printable_results[c] = {
                    "num": per_cat_total[c],
                    "acc": round(per_cat_correct.get(c, 0) / per_cat_total[c], 3),
                }

        # Overall, including any "Unknown" category.
        total_num = sum(per_cat_total.values())
        total_correct = sum(per_cat_correct.values())
        overall_acc = total_correct / total_num if total_num > 0 else 0.0
        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

        return EvalResult(
            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
        )
|
310
|
+
|
311
|
+
|
312
|
+
def _parse_multi_choice_response(
|
313
|
+
response: str, all_choices: List[str], index2ans: dict
|
314
|
+
) -> str:
|
315
|
+
# loosely adapted from benchmark mmmu eval
|
316
|
+
for char in [",", ".", "!", "?", ";", ":", "'"]:
|
317
|
+
response = response.strip(char)
|
318
|
+
response = " " + response + " "
|
319
|
+
|
320
|
+
# Prefer explicit letter with bracket e.g. (A)
|
321
|
+
candidates: List[str] = []
|
322
|
+
for choice in all_choices:
|
323
|
+
if f"({choice})" in response:
|
324
|
+
candidates.append(choice)
|
325
|
+
if not candidates:
|
326
|
+
for choice in all_choices:
|
327
|
+
if f" {choice} " in response:
|
328
|
+
candidates.append(choice)
|
329
|
+
if not candidates and len(response.split()) > 5:
|
330
|
+
# try match by option text
|
331
|
+
for idx, ans in index2ans.items():
|
332
|
+
if ans and ans.lower() in response.lower():
|
333
|
+
candidates.append(idx)
|
334
|
+
if not candidates:
|
335
|
+
# fallback to first choice
|
336
|
+
return all_choices[0]
|
337
|
+
if len(candidates) == 1:
|
338
|
+
return candidates[0]
|
339
|
+
# choose the last occurrence
|
340
|
+
starts = []
|
341
|
+
for can in candidates:
|
342
|
+
pos = response.rfind(f"({can})")
|
343
|
+
if pos == -1:
|
344
|
+
pos = response.rfind(f" {can} ")
|
345
|
+
if pos == -1 and index2ans.get(can):
|
346
|
+
pos = response.lower().rfind(index2ans[can].lower())
|
347
|
+
starts.append(pos)
|
348
|
+
return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
|
349
|
+
|
350
|
+
|
351
|
+
def _check_is_number(s: str) -> bool:
|
352
|
+
try:
|
353
|
+
float(s.replace(",", ""))
|
354
|
+
return True
|
355
|
+
except Exception:
|
356
|
+
return False
|
357
|
+
|
358
|
+
|
359
|
+
def _normalize_str(s: str):
|
360
|
+
s = s.strip()
|
361
|
+
if _check_is_number(s):
|
362
|
+
s = s.replace(",", "")
|
363
|
+
try:
|
364
|
+
v = round(float(s), 2)
|
365
|
+
return [v]
|
366
|
+
except Exception:
|
367
|
+
return [s.lower()]
|
368
|
+
return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
|
369
|
+
|
370
|
+
|
371
|
+
def _extract_numbers(s: str) -> List[str]:
|
372
|
+
import re as _re
|
373
|
+
|
374
|
+
pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
|
375
|
+
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
|
376
|
+
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
|
377
|
+
return (
|
378
|
+
_re.findall(pattern_commas, s)
|
379
|
+
+ _re.findall(pattern_scientific, s)
|
380
|
+
+ _re.findall(pattern_simple, s)
|
381
|
+
)
|
382
|
+
|
383
|
+
|
384
|
+
def _parse_open_response(response: str) -> List[str]:
    """Extract normalized candidate answers from a response to an open question.

    Steps:

    1. Strip the response (including trailing/leading periods), lower-case
       it, and split it into pieces.
    2. In each piece, for every indicator phrase present ("is ", "so ",
       "answer ", ...), take the text after its last occurrence. Keep the
       shortest such tail unless it is empty or a lone punctuation mark.
       ``"="`` is an extra indicator for the final piece only.
    3. Add the numbers found in those tails as extra candidates.
    4. Normalize every candidate with ``_normalize_str`` and drop duplicates
       while preserving order.

    If no piece yields a key, the whole stripped, lower-cased response is
    used as the single key.

    Returns:
        A list of normalized candidates (strings and/or rounded floats).
    """
    import re as _re

    indicators = (
        "could be ",
        "so ",
        "is ",
        "thus ",
        "therefore ",
        "final ",
        "answer ",
        "result ",
    )
    ignored = {":", ",", ".", "!", "?", ";", "'"}

    text = response.strip().strip(".").lower()
    # NOTE(review): the sentence-split lookahead requires an upper-case
    # letter, but `text` is already lower-cased, so in practice only
    # newlines split it. Kept as-is to avoid changing scores; confirm the
    # intent before "fixing" the order.
    pieces = _re.split(r"\.\s(?=[A-Z])|\n", text)
    last = len(pieces) - 1

    keys: List[str] = []
    for idx, piece in enumerate(pieces):
        markers = indicators + ("=",) if idx == last else indicators
        best = None
        for marker in markers:
            if marker not in piece:
                continue
            tail = piece.split(marker)[-1].strip()
            if not best or len(tail) < len(best):
                best = tail
        if best and best not in ignored:
            keys.append(best)
    if not keys:
        keys = [text]

    candidates = list(keys)
    for key in keys:
        candidates.extend(_extract_numbers(key))

    normalized = [form for cand in candidates for form in _normalize_str(cand)]
    # dict keeps insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(normalized))
|
424
|
+
|
425
|
+
|
426
|
+
def _eval_open(gold, preds: List[str]) -> bool:
    """Decide whether any prediction matches the gold answer(s).

    ``gold`` may be a single answer or a list of acceptable answers; each is
    normalized with ``_normalize_str``. A string prediction is a hit if it
    contains any normalized *string* answer as a substring. A non-string
    prediction (e.g. a rounded float from number normalization) is a hit if
    it is equal to one of the normalized answers.

    Returns:
        True on the first matching prediction, otherwise False.
    """
    gold_items = gold if isinstance(gold, list) else [gold]
    norm_answers = [form for item in gold_items for form in _normalize_str(item)]

    def _matches(pred) -> bool:
        if isinstance(pred, str):
            return any(isinstance(na, str) and na in pred for na in norm_answers)
        return pred in norm_answers

    return any(_matches(p) for p in preds)
|
sglang/test/test_block_fp8.py
CHANGED
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
|
|
621
621
|
w_s,
|
622
622
|
)
|
623
623
|
|
624
|
-
from deep_gemm import
|
624
|
+
from deep_gemm import fp8_m_grouped_gemm_nt_masked
|
625
625
|
|
626
626
|
with torch.inference_mode():
|
627
627
|
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
628
|
-
|
628
|
+
fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
629
629
|
out = oe[:, :M, :]
|
630
630
|
|
631
631
|
self.assertTrue(
|
sglang/test/test_cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
|
|
9
9
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
10
10
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
11
11
|
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
12
|
+
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
12
13
|
|
13
14
|
|
14
15
|
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
@@ -21,7 +22,7 @@ def calc_diff(x, y):
|
|
21
22
|
|
22
23
|
def get_model_config(tp_size: int):
|
23
24
|
config = AutoConfig.from_pretrained(
|
24
|
-
"deepseek-ai/
|
25
|
+
"deepseek-ai/Deepseek-R1", trust_remote_code=True
|
25
26
|
)
|
26
27
|
E = config.n_routed_experts
|
27
28
|
topk = config.num_experts_per_tok
|
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|
152
153
|
problem_sizes2,
|
153
154
|
)
|
154
155
|
|
156
|
+
topk_output = StandardTopKOutput(
|
157
|
+
topk_weights=topk_weights,
|
158
|
+
topk_ids=topk_ids,
|
159
|
+
router_logits=torch.randn(
|
160
|
+
(batch_size, topk), device=topk_weights.device, dtype=dtype
|
161
|
+
),
|
162
|
+
)
|
163
|
+
|
164
|
+
moe_runner_config = MoeRunnerConfig(
|
165
|
+
num_experts=E,
|
166
|
+
top_k=topk,
|
167
|
+
hidden_size=H,
|
168
|
+
intermediate_size_per_partition=I,
|
169
|
+
params_dtype=dtype,
|
170
|
+
activation="silu",
|
171
|
+
inplace=False,
|
172
|
+
)
|
173
|
+
|
155
174
|
# Note: Triton expects non-transposed weights
|
156
|
-
moe_config = MoeRunnerConfig(inplace=False)
|
157
175
|
triton_lambda = lambda: fused_experts(
|
158
176
|
x,
|
159
177
|
w1,
|
160
178
|
w2,
|
161
|
-
|
162
|
-
|
179
|
+
topk_output,
|
180
|
+
moe_runner_config,
|
163
181
|
use_fp8_w8a8=True,
|
164
182
|
w1_scale=w1_scale,
|
165
183
|
w2_scale=w2_scale,
|
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|
224
242
|
x,
|
225
243
|
w1, # Original shape
|
226
244
|
w2, # Original shape
|
227
|
-
|
228
|
-
|
245
|
+
topk_output,
|
246
|
+
moe_runner_config,
|
229
247
|
use_fp8_w8a8=True,
|
230
248
|
w1_scale=w1_scale,
|
231
249
|
w2_scale=w2_scale,
|