sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +378 -160
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +10 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +136 -25
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +63 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +83 -80
- sglang/srt/entrypoints/grpc_server.py +430 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +195 -102
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +58 -6
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +33 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +20 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +10 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +24 -10
- sglang/srt/layers/attention/flashinfer_backend.py +258 -22
- sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
- sglang/srt/layers/attention/utils.py +89 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +12 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +64 -19
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +152 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
- sglang/srt/layers/moe/ep_moe/layer.py +154 -625
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
- sglang/srt/layers/moe/moe_runner/runner.py +6 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
- sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +7 -6
- sglang/srt/layers/moe/utils.py +20 -5
- sglang/srt/layers/quantization/__init__.py +5 -58
- sglang/srt/layers/quantization/awq.py +183 -9
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +27 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +152 -81
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +35 -68
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +23 -48
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +87 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +62 -9
- sglang/srt/layers/rotary_embedding.py +686 -17
- sglang/srt/layers/sampler.py +47 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +69 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +420 -514
- sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +375 -95
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +11 -2
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +517 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +71 -25
- sglang/srt/model_executor/model_runner.py +362 -270
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +418 -140
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +327 -382
- sglang/srt/models/glm4_moe_nextn.py +6 -16
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +32 -199
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +19 -37
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +7 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +2 -6
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +28 -2
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +846 -163
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +36 -31
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +272 -82
- sglang/srt/utils/hf_transformers_utils.py +44 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +463 -107
- sglang/test/test_deterministic_utils.py +74 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/models/vila.py +0 -306
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import logging
|
|
5
|
+
from contextlib import ExitStack
|
|
6
|
+
from typing import Any, Callable, Optional, Union
|
|
7
|
+
from unittest.mock import patch
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.fx as fx
|
|
11
|
+
|
|
12
|
+
import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401
|
|
13
|
+
from sglang.srt.compilation.compilation_config import CompilationConfig
|
|
14
|
+
from sglang.srt.compilation.compilation_counter import compilation_counter
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Return a weak-reference view of ``tensor`` when it is a ``torch.Tensor``.

    The returned tensor shares the same data as the original tensor,
    but does not keep the original tensor alive. Any value that is not a
    ``torch.Tensor`` is handed back unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
    return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """
    Apply ``weak_ref_tensor`` to a single tensor or to every element of a
    list or tuple of tensors.

    A list input yields a new list and a tuple input yields a new tuple.
    Any other input type raises ``ValueError``.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    # Rebuild the container with the same builtin type it came in as;
    # list is checked before tuple.
    for container_type in (list, tuple):
        if isinstance(tensors, container_type):
            return container_type(weak_ref_tensor(item) for item in tensors)
    raise ValueError("Invalid type for tensors")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape state kept by the piecewise backend for one runtime shape."""

    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    # callable used to run this shape; None until the backend assigns one
    runnable: Optional[Callable] = None
    # number of plain (non-captured) runs already done for this shape
    num_finished_warmup: int = 0
    # captured graph for this shape, once capture has happened
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # output recorded at capture time and returned on replay
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class CUDAPiecewiseBackend:
    """Callable that runs one piece of a piecewise-compiled graph.

    Instances are called with the real runtime arguments of the piece.
    Depending on the runtime shape, a call runs the general-shape compiled
    graph, compiles a shape-specific graph, captures a CUDA graph, or
    replays a previously captured CUDA graph.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        compile_config: CompilationConfig,
        inductor_config: dict[str, Any],
        graph_pool: Any,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        sglang_backend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the graph piece handed to the compiler for
                shape-specific compilation.
            compile_config: provides `get_capture_sizes()`, the shapes to
                capture CUDA graphs for.
            inductor_config: forwarded unchanged to the compiler manager.
            graph_pool: passed as `pool=` to `torch.cuda.graph`.
            piecewise_compile_index: index of this piece.
            total_piecewise_compiles: total number of pieces.
            sym_shape_indices: positions in `args`; the first one is used
                to read the runtime shape.
            compiled_graph_for_general_shape: callable used for the first
                run and for every shape without a dedicated entry.
            sglang_backend: object exposing `compiler_manager`.
        """
        self.graph = graph
        self.inductor_config = inductor_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.sglang_backend = sglang_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        # No shape-specific compile sizes are configured here, so this set
        # starts (and, within this class, stays) empty.
        self.compile_sizes: set[int] = set()
        self.compile_config = compile_config
        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        # When true, replay asserts that tensor input addresses match the
        # ones recorded at capture time.
        self.is_debugging_mode = True

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Save compiler state once the last piece has nothing left to compile."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.sglang_backend.compiler_manager.save_to_file()

    def __call__(self, *args) -> Any:
        """Run this piece for the runtime shape found in `args`.

        The first call always runs the general-shape graph. Later calls
        look up the shape's entry and, in order: compile a shape-specific
        runnable if required, do one warm-up run, capture a CUDA graph, and
        from then on replay the captured graph.

        Returns:
            The output of the runnable, or on replay the output recorded at
            capture time.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.sglang_backend.compiler_manager.compile(
                self.graph,
                args,
                self.inductor_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
            )

            # finished compilations for all required shapes
            # (the callee re-checks is_last_graph / to_be_compiled_sizes)
            self.check_for_ending_compilation()

        # Skip CUDA graphs if this entry doesn't use them, so a compile-only
        # shape is never captured.
        # TODO: also honour a global "skip cuda graphs" switch once a
        # forward context providing it is available.
        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < 1:  # noqa
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import operator
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from typing import Optional, Union
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
10
|
+
|
|
11
|
+
from sglang.srt.compilation.fx_utils import is_func
|
|
12
|
+
from sglang.srt.compilation.inductor_pass import SGLangInductorPass
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FixFunctionalizationPass(SGLangInductorPass):
    """
    Pass that provides the machinery to turn ``auto_functionalized`` nodes back
    into direct calls of the wrapped op, avoiding redundant tensor copies.

    As written, ``__call__`` only counts the ``auto_functionalized`` nodes in
    the graph and erases whatever has been staged in ``self.nodes_to_remove``;
    no op is rewritten by ``__call__`` itself. The rewriting helpers
    (``defunctionalize`` and friends) are available for per-op handling.

    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.
    """

    def __call__(self, graph: torch.fx.Graph):
        """Run the pass over ``graph`` in place, dumping it before and after."""
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        # Reset the staging list on every invocation.
        self.nodes_to_remove: list[torch.fx.Node] = []
        count = sum(
            1 for candidate in graph.nodes if is_func(candidate, auto_functionalized)
        )

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Erase every staged node in one sweep, in the order it was staged.
        staged = self.nodes_to_remove
        count_removed = len(staged)
        for stale in staged:
            graph.erase_node(stale)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
        """
        Stage one node, or every node of an iterable, for erasure at the end
        of the pass.
        """
        if not isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.extend(node_or_nodes)
            return
        self.nodes_to_remove.append(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, Union[torch.fx.Node, str]],
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        De-functionalize ``node``: reroute its getitem users to the mutated
        arguments, insert a direct call to the wrapped function before it, and
        stage ``node`` itself for removal.

        :param graph: Graph that owns ``node``
        :param node: The auto-functionalized node
        :param mutated_args: Passed to replace_users_with_mutated_args
        :param args: Passed to insert_defunctionalized
        """
        # The getitem users are staged before the node, so they are erased first.
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
    ):
        """
        Redirect every getitem user of the auto-functionalized node to the
        corresponding mutated argument and stage the getitem for removal.

        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        A string value is looked up as `node.kwargs[value]`.
        """
        for index, getitem_node in self.getitem_users(node).items():
            replacement = mutated_args[index]
            if isinstance(replacement, str):
                replacement = node.kwargs[replacement]
            getitem_node.replace_all_uses_with(replacement)
            self._remove(getitem_node)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Collect the operator.getitem users of ``node``, keyed by the index each
        one extracts. If several users extract the same index, the last one
        iterated wins.
        """
        return {
            consumer.args[1]: consumer
            for consumer in node.users
            if is_func(consumer, operator.getitem)
        }

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        Insert a direct call to the wrapped function into the graph, right
        before ``node``. By default the call receives ``node.kwargs``; when one
        of the kwargs is 'out', kwargs cannot be used and positional ``args``
        must be given instead.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: Positional arguments to use instead of kwargs.
        A string entry is looked up as `node.kwargs[entry]`.
        """  # noqa: E501
        assert is_func(
            node, auto_functionalized
        ), f"node must be auto-functionalized, is {node} instead"

        # The wrapped op is the first positional argument of auto_functionalized.
        wrapped_fn = node.args[0]
        with graph.inserting_before(node):
            if args is not None:
                resolved = tuple(
                    node.kwargs[item] if isinstance(item, str) else item
                    for item in args
                )
                graph.call_function(wrapped_fn, args=resolved)
            else:
                graph.call_function(wrapped_fn, kwargs=node.kwargs)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
|
|
2
|
+
|
|
3
|
+
import operator
|
|
4
|
+
from collections.abc import Iterable, Iterator
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from torch import fx
|
|
8
|
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
9
|
+
from torch._ops import OpOverload
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def is_func(node: fx.Node, target) -> bool:
    """Tell whether ``node`` is a ``call_function`` node whose target equals ``target``."""
    if node.op != "call_function":
        return False
    return node.target == target
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """Tell whether ``node`` is an ``auto_functionalized`` call whose first argument equals ``op``."""
    if not is_func(node, auto_functionalized):
        return False
    return node.args[0] == op
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def find_specified_fn_maybe(
    nodes: Iterable[fx.Node], op: OpOverload
) -> Optional[fx.Node]:
    """Return the first node in ``nodes`` whose target equals ``op``, or None if there is none."""
    return next((candidate for candidate in nodes if candidate.target == op), None)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """
    Return the first node in ``nodes`` whose target equals ``op``.

    Fails an assertion when no such node exists.
    """
    match = find_specified_fn_maybe(nodes, op)
    assert match is not None, f"Could not find {op} in nodes {nodes}"
    return match
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
    """
    Return the first ``auto_functionalized`` node in ``nodes`` that wraps ``op``,
    or None if there is none.
    """
    # is_auto_func evaluates the same condition: an auto_functionalized
    # call_function node whose first argument equals op.
    return next((candidate for candidate in nodes if is_auto_func(candidate, op)), None)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """
    Return the first ``auto_functionalized`` node in ``nodes`` that wraps ``op``.

    Fails an assertion when no such node exists.
    """
    match = find_auto_fn_maybe(nodes, op)
    assert match is not None, f"Could not find {op} in nodes {nodes}"
    return match
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    """
    Return the ``operator.getitem`` user of ``node`` that extracts element
    ``idx``, or None if no user does.
    """
    return next(
        (
            consumer
            for consumer in node.users
            if is_func(consumer, operator.getitem) and consumer.args[1] == idx
        ),
        None,
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """
    Return the ``operator.getitem`` user of ``node`` that extracts element ``idx``.

    Fails an assertion when no such user exists.
    """
    getitem_node = find_getitem_maybe(node, idx)
    assert getitem_node is not None, f"Could not find getitem {idx} in node {node}"
    return getitem_node
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# An auto-functionalization-aware utility for finding nodes with a specific op
|
|
69
|
+
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
|
70
|
+
if not op._schema.is_mutable:
|
|
71
|
+
yield from graph.find_nodes(op="call_function", target=op)
|
|
72
|
+
|
|
73
|
+
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
|
74
|
+
if n.args[0] == op:
|
|
75
|
+
yield n
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def get_only_user(node: fx.Node) -> fx.Node:
    """
    Assert that ``node`` has exactly one user and return that user.

    Even if a node has only 1 user, it might share storage with another node,
    which might need to be taken into account.
    """
    consumers = node.users
    assert len(consumers) == 1
    return next(iter(consumers))
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
import types
|
|
9
|
+
from contextlib import contextmanager
|
|
10
|
+
from typing import Any, Callable, Optional, Union
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from torch import fx
|
|
14
|
+
from torch._dynamo.utils import lazy_format_graph_code
|
|
15
|
+
from torch._inductor.custom_graph_pass import CustomGraphPass
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
_pass_context = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PassContext:
    """Holds the runtime shape (an int or None) for the passes currently being run."""

    def __init__(self, runtime_shape: Optional[int]):
        # Stored as given, without validation.
        self.runtime_shape = runtime_shape
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_pass_context() -> PassContext:
    """
    Return the pass context that is currently active.

    Fails an assertion when no context has been installed in the
    module-level ``_pass_context``.
    """
    active = _pass_context
    assert active is not None
    return active
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@contextmanager
def pass_context(runtime_shape: Optional[int]):
    """
    Context manager that installs a fresh PassContext carrying
    ``runtime_shape`` as the module-level current context, and puts the
    previous context back on exit, including when the body raises.
    """
    global _pass_context
    saved_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        # Restoring the saved value keeps nested uses well-behaved.
        _pass_context = saved_context
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class InductorPass(CustomGraphPass):
|
|
49
|
+
"""
|
|
50
|
+
A custom graph pass that uses a hash of its source as the UUID.
|
|
51
|
+
This is defined as a convenience and should work in most cases.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
def uuid(self) -> Any:
|
|
55
|
+
"""
|
|
56
|
+
Provide a unique identifier for the pass, used in Inductor code cache.
|
|
57
|
+
This should depend on the pass implementation, so that changes to the
|
|
58
|
+
pass result in recompilation.
|
|
59
|
+
By default, the object source is hashed.
|
|
60
|
+
"""
|
|
61
|
+
return InductorPass.hash_source(self)
|
|
62
|
+
|
|
63
|
+
@staticmethod
|
|
64
|
+
def hash_source(*srcs: Union[str, Any]):
|
|
65
|
+
"""
|
|
66
|
+
Utility method to hash the sources of functions or objects.
|
|
67
|
+
:param srcs: strings or objects to add to the hash.
|
|
68
|
+
Objects and functions have their source inspected.
|
|
69
|
+
:return:
|
|
70
|
+
"""
|
|
71
|
+
hasher = hashlib.sha256()
|
|
72
|
+
for src in srcs:
|
|
73
|
+
if isinstance(src, str):
|
|
74
|
+
src_str = src
|
|
75
|
+
elif isinstance(src, types.FunctionType):
|
|
76
|
+
src_str = inspect.getsource(src)
|
|
77
|
+
else:
|
|
78
|
+
src_str = inspect.getsource(src.__class__)
|
|
79
|
+
hasher.update(src_str.encode("utf-8"))
|
|
80
|
+
return hasher.hexdigest()
|
|
81
|
+
|
|
82
|
+
@staticmethod
|
|
83
|
+
def hash_dict(dict_: dict[Any, Any]):
|
|
84
|
+
"""
|
|
85
|
+
Utility method to hash a dictionary, can alternatively be used for uuid.
|
|
86
|
+
:return: A sha256 hash of the json rep of the dictionary.
|
|
87
|
+
"""
|
|
88
|
+
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
|
89
|
+
return hashlib.sha256(encoded).hexdigest()
|
|
90
|
+
|
|
91
|
+
def is_applicable_for_shape(self, shape: Optional[int]):
|
|
92
|
+
return True
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class CallableInductorPass(InductorPass):
    """
    Adapter that turns a plain callable into an InductorPass and supplies the
    UUID for it: either the one given, or a hash of the callable's source.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
    ):
        self.callable = callable
        # Only hash the source when the caller did not provide an identifier.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph):
        """Apply the wrapped callable to ``graph``; its return value is ignored."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the identifier fixed at construction time."""
        return self._uuid
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class SGLangInductorPass(InductorPass):
    """
    InductorPass base with a recorded pass name plus helpers for dumping the
    graph and for timing a pass run.
    """

    def __init__(
        self,
    ):
        # The name of the concrete subclass, used when logging the duration.
        self.pass_name = self.__class__.__name__

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """Format the graph's owning module under the label ``stage``."""
        # The value returned by lazy_format_graph_code is not used here.
        lazy_format_graph_code(stage, graph.owning_module)

    def begin(self):
        """Record the start timestamp of the pass, in nanoseconds."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        """
        Record the end timestamp and log the elapsed time in milliseconds at
        debug level. ``begin`` must have been called first.
        """
        self._end_time = time.perf_counter_ns()
        elapsed_ns = self._end_time - self._start_time
        duration_ms = float(elapsed_ns) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class PrinterInductorPass(SGLangInductorPass):
    """Pass that does nothing but dump the graph under a fixed stage name."""

    def __init__(self, name: str):
        super().__init__()
        # Label handed to dump_graph on every call.
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        """Dump ``graph`` using the name given at construction."""
        self.dump_graph(graph, self.name)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from torch import fx as fx
|
|
6
|
+
|
|
7
|
+
from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass
|
|
8
|
+
from sglang.srt.compilation.inductor_pass import (
|
|
9
|
+
CustomGraphPass,
|
|
10
|
+
InductorPass,
|
|
11
|
+
SGLangInductorPass,
|
|
12
|
+
get_pass_context,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.

    It keeps the list of registered passes, runs them, and exposes a uuid for
    the Inductor code cache.

    When called on a graph, the order is:
    1. every pass registered through ``add``, in registration order, skipping
       those not applicable to the current runtime shape
    2. fix_functionalization
    This way, all registered passes operate on a functionalized graph.

    ``configure`` must be called before the manager is called or asked for its
    uuid, because it is what creates ``self.fix_functionalization``.
    """

    def __init__(self):
        self.passes: list[SGLangInductorPass] = []

    def __call__(self, graph: fx.Graph):
        """Run the applicable registered passes on ``graph``, then fix_functionalization."""
        runtime_shape = get_pass_context().runtime_shape
        for registered in self.passes:
            if not registered.is_applicable_for_shape(runtime_shape):
                continue
            registered(graph)

        # always run fix_functionalization last
        self.fix_functionalization(graph)

    def configure(
        self,
    ):
        """Reset the pass config to an empty dict and create the fix_functionalization pass."""
        self.pass_config = {}
        self.fix_functionalization = FixFunctionalizationPass()

    def add(self, pass_: InductorPass):
        """Append ``pass_`` to the registered passes; it must be an InductorPass."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid is a hash over a fixed manager
        tag and the UUIDs of all registered passes followed by that of
        fix_functionalization. See InductorPass for more info.
        """
        pass_manager_uuid = "fshdakhsa"
        pass_uuids = [registered.uuid() for registered in self.passes]
        pass_uuids.append(self.fix_functionalization.uuid())
        state = {"pass_config": pass_manager_uuid, "passes": pass_uuids}
        return InductorPass.hash_dict(state)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, List, Optional
|
|
4
|
+
|
|
5
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ForwardContext:
    """
    Mutable holder for the state of the current forward pass: the forward
    batch and the list of attention layers.

    Both start out as None and are filled in through the setters.
    """

    def __init__(self):
        self.forward_batch = None
        # Initialised here so that reading ``attention_layers`` before
        # set_attention_layers() yields None instead of raising AttributeError.
        self.attention_layers = None
        # Singular spelling from the original initialiser. Nothing in this
        # class writes it again; it is kept so existing readers of the
        # attribute keep working.
        self.attention_layer = None

    def set_forward_batch(self, forward_batch: ForwardBatch):
        """Store ``forward_batch`` as the batch of the current forward pass."""
        self.forward_batch = forward_batch

    def set_attention_layers(self, layers: List[Any]):
        """Store ``layers`` as the attention layers of the current forward pass."""
        self.attention_layers = layers
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
_forward_context: Optional[ForwardContext] = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_forward_context() -> Optional[ForwardContext]:
    """Return the forward context that is currently installed, or None when there is none."""
    return _forward_context
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@contextmanager
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
    """
    Context manager that installs a new ForwardContext holding
    ``forward_batch`` and ``attention_layers`` as the module-level current
    context, and puts the previous context back on exit, including when the
    body raises.
    """
    global _forward_context
    saved_context = _forward_context
    # The new context is published first and populated afterwards.
    _forward_context = fresh_context = ForwardContext()
    fresh_context.set_forward_batch(forward_batch)
    fresh_context.set_attention_layers(attention_layers)
    try:
        yield
    finally:
        _forward_context = saved_context
|