sglang 0.4.6.post5__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_offline_throughput.py +10 -4
- sglang/bench_one_batch_server.py +67 -11
- sglang/bench_serving.py +86 -75
- sglang/lang/backend/runtime_endpoint.py +24 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/profiler.py +167 -0
- sglang/srt/_custom_ops.py +34 -0
- sglang/srt/configs/internvl.py +8 -12
- sglang/srt/configs/model_config.py +33 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -2
- sglang/srt/constrained/llguidance_backend.py +9 -8
- sglang/srt/constrained/outlines_backend.py +5 -4
- sglang/srt/constrained/xgrammar_backend.py +18 -18
- sglang/srt/conversation.py +52 -8
- sglang/srt/custom_op.py +38 -3
- sglang/srt/debug_utils.py +74 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -0
- sglang/srt/disaggregation/common/conn.py +407 -0
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +261 -52
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +16 -9
- sglang/srt/disaggregation/kv_events.py +60 -5
- sglang/srt/disaggregation/launch_lb.py +140 -0
- sglang/srt/disaggregation/mini_lb.py +29 -48
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +446 -149
- sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +134 -437
- sglang/srt/disaggregation/prefill.py +130 -43
- sglang/srt/disaggregation/utils.py +127 -86
- sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
- sglang/srt/distributed/parallel_state.py +52 -5
- sglang/srt/entrypoints/EngineBase.py +6 -0
- sglang/srt/entrypoints/engine.py +116 -5
- sglang/srt/entrypoints/http_server.py +28 -4
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +138 -86
- sglang/srt/function_call/deepseekv3_detector.py +54 -6
- sglang/srt/function_call/ebnf_composer.py +33 -19
- sglang/srt/function_call/function_call_parser.py +27 -0
- sglang/srt/function_call/llama32_detector.py +33 -14
- sglang/srt/function_call/mistral_detector.py +73 -26
- sglang/srt/function_call/pythonic_detector.py +86 -20
- sglang/srt/function_call/qwen25_detector.py +64 -10
- sglang/srt/function_call/utils.py +17 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +503 -125
- sglang/srt/layers/attention/base_attn_backend.py +4 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +40 -34
- sglang/srt/layers/attention/flashattention_backend.py +137 -63
- sglang/srt/layers/attention/flashinfer_backend.py +46 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +59 -25
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/intel_amx_backend.py +128 -0
- sglang/srt/layers/attention/tbo_backend.py +232 -0
- sglang/srt/layers/attention/torch_native_backend.py +3 -0
- sglang/srt/layers/attention/triton_backend.py +304 -65
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +281 -197
- sglang/srt/layers/dp_attention.py +6 -5
- sglang/srt/layers/layernorm.py +30 -19
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/cutlass_moe.py +170 -7
- sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +136 -72
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +24 -45
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +221 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
- sglang/srt/layers/moe/topk.py +60 -26
- sglang/srt/layers/multimodal.py +3 -3
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/blockwise_int8.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +69 -127
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +28 -23
- sglang/srt/layers/quantization/fp8_kernel.py +156 -75
- sglang/srt/layers/quantization/fp8_utils.py +250 -69
- sglang/srt/layers/quantization/modelopt_quant.py +334 -7
- sglang/srt/layers/quantization/moe_wna16.py +3 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +6 -12
- sglang/srt/layers/sampler.py +80 -79
- sglang/srt/layers/utils.py +6 -0
- sglang/srt/lora/layers.py +12 -15
- sglang/srt/lora/lora.py +49 -5
- sglang/srt/lora/lora_manager.py +98 -39
- sglang/srt/lora/mem_pool.py +28 -21
- sglang/srt/lora/utils.py +17 -13
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +13 -5
- sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
- sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
- sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
- sglang/srt/managers/eplb_manager.py +55 -14
- sglang/srt/managers/expert_distribution.py +220 -46
- sglang/srt/managers/expert_location.py +110 -56
- sglang/srt/managers/expert_location_dispatch.py +23 -6
- sglang/srt/managers/io_struct.py +43 -8
- sglang/srt/managers/mm_utils.py +88 -38
- sglang/srt/managers/multimodal_processors/base_processor.py +190 -18
- sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
- sglang/srt/managers/multimodal_processors/internvl.py +4 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
- sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
- sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +173 -38
- sglang/srt/managers/scheduler.py +376 -127
- sglang/srt/managers/tokenizer_manager.py +163 -19
- sglang/srt/managers/utils.py +0 -4
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +191 -113
- sglang/srt/model_executor/expert_location_updater.py +157 -22
- sglang/srt/model_executor/forward_batch_info.py +52 -22
- sglang/srt/model_executor/model_runner.py +102 -62
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/utils.py +67 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +1 -1
- sglang/srt/models/deepseek_v2.py +623 -290
- sglang/srt/models/gemma3_causal.py +7 -0
- sglang/srt/models/gemma3_mm.py +19 -14
- sglang/srt/models/idefics2.py +342 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/kimi_vl.py +4 -4
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/minicpmo.py +2 -5
- sglang/srt/models/minicpmv.py +3 -295
- sglang/srt/models/phi4mm.py +512 -0
- sglang/srt/models/qwen2.py +38 -9
- sglang/srt/models/qwen2_5_vl.py +3 -9
- sglang/srt/models/qwen2_eagle.py +4 -1
- sglang/srt/models/qwen2_moe.py +58 -191
- sglang/srt/models/qwen2_vl.py +3 -9
- sglang/srt/models/qwen3.py +41 -10
- sglang/srt/models/qwen3_moe.py +230 -191
- sglang/srt/models/registry.py +9 -1
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/transformers.py +291 -0
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +248 -28
- sglang/srt/openai_api/protocol.py +68 -3
- sglang/srt/openai_api/utils.py +172 -0
- sglang/srt/operations.py +37 -2
- sglang/srt/operations_strategy.py +200 -24
- sglang/srt/sampling/sampling_batch_info.py +37 -1
- sglang/srt/sampling/sampling_params.py +4 -1
- sglang/srt/server_args.py +381 -209
- sglang/srt/speculative/build_eagle_tree.py +9 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +12 -14
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +256 -0
- sglang/srt/speculative/eagle_utils.py +440 -200
- sglang/srt/speculative/eagle_worker.py +234 -63
- sglang/srt/two_batch_overlap.py +637 -0
- sglang/srt/utils.py +187 -7
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +54 -10
- sglang/test/send_one.py +4 -0
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_cutlass_moe.py +3 -3
- sglang/test/test_fp4_moe.py +248 -0
- sglang/test/test_utils.py +82 -7
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +17 -14
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +359 -321
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +1 -1
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import logging
|
3
4
|
import os
|
5
|
+
import time
|
4
6
|
from dataclasses import dataclass
|
5
|
-
from typing import
|
7
|
+
from typing import List, Optional
|
6
8
|
|
7
9
|
import torch
|
8
10
|
import torch.nn.functional as F
|
@@ -12,6 +14,7 @@ import triton.language as tl
|
|
12
14
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
13
15
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
14
16
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
17
|
+
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
15
18
|
from sglang.srt.managers.schedule_batch import (
|
16
19
|
Req,
|
17
20
|
ScheduleBatch,
|
@@ -19,10 +22,8 @@ from sglang.srt.managers.schedule_batch import (
|
|
19
22
|
global_server_args_dict,
|
20
23
|
)
|
21
24
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
22
|
-
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
23
|
-
from sglang.srt.
|
24
|
-
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
25
|
-
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
25
|
+
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
26
|
+
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
26
27
|
|
27
28
|
if is_cuda():
|
28
29
|
from sgl_kernel import (
|
@@ -31,18 +32,19 @@ if is_cuda():
|
|
31
32
|
tree_speculative_sampling_target_only,
|
32
33
|
verify_tree_greedy,
|
33
34
|
)
|
35
|
+
from sgl_kernel.top_k import fast_topk
|
34
36
|
elif is_hip():
|
35
37
|
from sgl_kernel import verify_tree_greedy
|
36
38
|
|
37
|
-
if TYPE_CHECKING:
|
38
|
-
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
39
|
-
|
40
|
-
import logging
|
41
39
|
|
42
40
|
logger = logging.getLogger(__name__)
|
43
41
|
|
44
42
|
|
43
|
+
# Simulate acceptance length for benchmarking purposes
|
45
44
|
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
45
|
+
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
|
46
|
+
|
47
|
+
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
46
48
|
|
47
49
|
|
48
50
|
@dataclass
|
@@ -66,8 +68,6 @@ class EagleDraftInput:
|
|
66
68
|
kv_indptr: torch.Tensor = None
|
67
69
|
kv_indices: torch.Tensor = None
|
68
70
|
|
69
|
-
all_padding_lens: Optional[torch.Tensor] = None
|
70
|
-
|
71
71
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
72
72
|
# Prefill only generate 1 token.
|
73
73
|
assert len(self.verified_id) == len(batch.seq_lens)
|
@@ -85,32 +85,29 @@ class EagleDraftInput:
|
|
85
85
|
batch: ScheduleBatch,
|
86
86
|
speculative_num_steps: int,
|
87
87
|
):
|
88
|
-
|
89
|
-
|
90
|
-
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
88
|
+
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
89
|
+
batch.input_ids = self.verified_id
|
90
|
+
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
91
91
|
batch.extend_num_tokens = sum(batch.extend_lens)
|
92
92
|
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
93
93
|
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
94
|
-
|
94
|
+
batch.return_logprob = False
|
95
|
+
batch.return_hidden_states = False
|
95
96
|
|
96
|
-
self.
|
97
|
-
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
97
|
+
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
98
98
|
self.accept_length.add_(1)
|
99
|
+
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
|
100
|
+
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
99
101
|
|
100
|
-
|
101
|
-
|
102
|
+
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
|
103
|
+
batch.input_ids,
|
102
104
|
batch.seq_lens,
|
103
105
|
self.accept_length,
|
104
|
-
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
105
106
|
self.positions,
|
106
|
-
|
107
|
-
next_power_of_2(speculative_num_steps + 1),
|
107
|
+
self.verified_id,
|
108
|
+
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
|
108
109
|
)
|
109
110
|
|
110
|
-
batch.seq_lens_sum = sum(seq_lens_cpu)
|
111
|
-
batch.input_ids = self.verified_id
|
112
|
-
self.verified_id = new_verified_id
|
113
|
-
|
114
111
|
def generate_attn_arg_prefill(
|
115
112
|
self,
|
116
113
|
req_pool_indices: torch.Tensor,
|
@@ -119,15 +116,17 @@ class EagleDraftInput:
|
|
119
116
|
req_to_token: torch.Tensor,
|
120
117
|
):
|
121
118
|
bs = self.accept_length.numel()
|
122
|
-
|
123
119
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
124
120
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
125
|
-
|
126
121
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
127
122
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
128
123
|
|
129
|
-
|
130
|
-
|
124
|
+
if paged_kernel_lens_sum is None:
|
125
|
+
paged_kernel_lens_sum = cum_kv_seq_len[-1]
|
126
|
+
|
127
|
+
kv_indices = torch.empty(
|
128
|
+
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
129
|
+
)
|
131
130
|
|
132
131
|
create_flashinfer_kv_indices_triton[(bs,)](
|
133
132
|
req_to_token,
|
@@ -138,7 +137,6 @@ class EagleDraftInput:
|
|
138
137
|
kv_indices,
|
139
138
|
req_to_token.size(1),
|
140
139
|
)
|
141
|
-
|
142
140
|
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
143
141
|
|
144
142
|
def filter_batch(self, new_indices: torch.Tensor):
|
@@ -187,56 +185,14 @@ class EagleVerifyInput:
|
|
187
185
|
retrive_next_token: torch.Tensor
|
188
186
|
retrive_next_sibling: torch.Tensor
|
189
187
|
retrive_cum_len: torch.Tensor
|
190
|
-
draft_token_num: int
|
191
188
|
spec_steps: int
|
189
|
+
topk: int
|
190
|
+
draft_token_num: int
|
192
191
|
capture_hidden_mode: CaptureHiddenMode
|
192
|
+
seq_lens_sum: int
|
193
|
+
seq_lens_cpu: torch.Tensor
|
193
194
|
grammar: BaseGrammarObject = None
|
194
195
|
|
195
|
-
@classmethod
|
196
|
-
def create(
|
197
|
-
cls,
|
198
|
-
verified_id: torch.Tensor,
|
199
|
-
score_list: List[torch.Tensor],
|
200
|
-
token_list: List[torch.Tensor],
|
201
|
-
parents_list: List[torch.Tensor],
|
202
|
-
seq_lens: torch.Tensor,
|
203
|
-
seq_lens_sum: int,
|
204
|
-
topk: int,
|
205
|
-
spec_steps: int,
|
206
|
-
num_verify_tokens: int,
|
207
|
-
):
|
208
|
-
(
|
209
|
-
tree_mask,
|
210
|
-
position,
|
211
|
-
retrive_index,
|
212
|
-
retrive_next_token,
|
213
|
-
retrive_next_sibling,
|
214
|
-
draft_tokens,
|
215
|
-
) = build_tree_kernel_efficient(
|
216
|
-
verified_id,
|
217
|
-
score_list,
|
218
|
-
token_list,
|
219
|
-
parents_list,
|
220
|
-
seq_lens,
|
221
|
-
seq_lens_sum,
|
222
|
-
topk,
|
223
|
-
spec_steps,
|
224
|
-
num_verify_tokens,
|
225
|
-
)
|
226
|
-
|
227
|
-
return cls(
|
228
|
-
draft_tokens,
|
229
|
-
tree_mask,
|
230
|
-
position,
|
231
|
-
retrive_index,
|
232
|
-
retrive_next_token,
|
233
|
-
retrive_next_sibling,
|
234
|
-
None,
|
235
|
-
num_verify_tokens,
|
236
|
-
spec_steps,
|
237
|
-
CaptureHiddenMode.FULL,
|
238
|
-
)
|
239
|
-
|
240
196
|
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
241
197
|
batch.input_ids = self.draft_token
|
242
198
|
|
@@ -311,7 +267,7 @@ class EagleVerifyInput:
|
|
311
267
|
logits_output: torch.Tensor,
|
312
268
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
313
269
|
page_size: int,
|
314
|
-
vocab_mask: Optional[torch.Tensor] = None,
|
270
|
+
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
315
271
|
) -> torch.Tensor:
|
316
272
|
"""
|
317
273
|
Verify and find accepted tokens based on logits output and batch
|
@@ -335,6 +291,14 @@ class EagleVerifyInput:
|
|
335
291
|
)
|
336
292
|
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
337
293
|
|
294
|
+
# Apply the custom logit processors if registered in the sampling info.
|
295
|
+
if sampling_info.has_custom_logit_processor:
|
296
|
+
apply_custom_logit_processor(
|
297
|
+
logits_output.next_token_logits,
|
298
|
+
sampling_info,
|
299
|
+
num_tokens_in_batch=self.draft_token_num,
|
300
|
+
)
|
301
|
+
|
338
302
|
# Apply penalty
|
339
303
|
if sampling_info.penalizer_orchestrator.is_required:
|
340
304
|
# This is a relaxed version of penalties for speculative decoding.
|
@@ -364,11 +328,11 @@ class EagleVerifyInput:
|
|
364
328
|
predicts=predict, # mutable
|
365
329
|
accept_index=accept_index, # mutable
|
366
330
|
accept_token_num=accept_length, # mutable
|
367
|
-
candidates=candidates
|
368
|
-
retrive_index=self.retrive_index
|
369
|
-
retrive_next_token=self.retrive_next_token
|
370
|
-
retrive_next_sibling=self.retrive_next_sibling
|
371
|
-
target_predict=target_predict
|
331
|
+
candidates=candidates,
|
332
|
+
retrive_index=self.retrive_index,
|
333
|
+
retrive_next_token=self.retrive_next_token,
|
334
|
+
retrive_next_sibling=self.retrive_next_sibling,
|
335
|
+
target_predict=target_predict,
|
372
336
|
)
|
373
337
|
else:
|
374
338
|
# apply temperature and get target probs
|
@@ -396,16 +360,23 @@ class EagleVerifyInput:
|
|
396
360
|
draft_probs = torch.zeros(
|
397
361
|
target_probs.shape, dtype=torch.float32, device="cuda"
|
398
362
|
)
|
363
|
+
|
364
|
+
# coins for rejection sampling
|
399
365
|
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
366
|
+
# coins for final sampling
|
367
|
+
coins_for_final_sampling = torch.rand(
|
368
|
+
(bs,), dtype=torch.float32, device="cuda"
|
369
|
+
)
|
400
370
|
tree_speculative_sampling_target_only(
|
401
371
|
predicts=predict, # mutable
|
402
372
|
accept_index=accept_index, # mutable
|
403
373
|
accept_token_num=accept_length, # mutable
|
404
|
-
candidates=candidates
|
405
|
-
retrive_index=self.retrive_index
|
406
|
-
retrive_next_token=self.retrive_next_token
|
407
|
-
retrive_next_sibling=self.retrive_next_sibling
|
374
|
+
candidates=candidates,
|
375
|
+
retrive_index=self.retrive_index,
|
376
|
+
retrive_next_token=self.retrive_next_token,
|
377
|
+
retrive_next_sibling=self.retrive_next_sibling,
|
408
378
|
uniform_samples=coins,
|
379
|
+
uniform_samples_for_final_sampling=coins_for_final_sampling,
|
409
380
|
target_probs=target_probs,
|
410
381
|
draft_probs=draft_probs,
|
411
382
|
threshold_single=global_server_args_dict[
|
@@ -428,8 +399,8 @@ class EagleVerifyInput:
|
|
428
399
|
spec_steps=self.spec_steps,
|
429
400
|
)
|
430
401
|
|
431
|
-
new_accept_index = []
|
432
402
|
unfinished_index = []
|
403
|
+
unfinished_accept_index = []
|
433
404
|
accept_index_cpu = accept_index.tolist()
|
434
405
|
predict_cpu = predict.tolist()
|
435
406
|
has_finished = False
|
@@ -437,12 +408,10 @@ class EagleVerifyInput:
|
|
437
408
|
# Iterate every accepted token and check if req has finished after append the token
|
438
409
|
# should be checked BEFORE free kv cache slots
|
439
410
|
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
440
|
-
new_accept_index_ = []
|
441
411
|
for j, idx in enumerate(accept_index_row):
|
442
412
|
if idx == -1:
|
443
413
|
break
|
444
414
|
id = predict_cpu[idx]
|
445
|
-
# if not found_finished:
|
446
415
|
req.output_ids.append(id)
|
447
416
|
req.check_finished()
|
448
417
|
if req.finished():
|
@@ -451,8 +420,6 @@ class EagleVerifyInput:
|
|
451
420
|
accept_index[i, j + 1 :] = -1
|
452
421
|
break
|
453
422
|
else:
|
454
|
-
new_accept_index_.append(idx)
|
455
|
-
# update grammar state
|
456
423
|
if req.grammar is not None:
|
457
424
|
try:
|
458
425
|
req.grammar.accept_token(id)
|
@@ -462,50 +429,104 @@ class EagleVerifyInput:
|
|
462
429
|
)
|
463
430
|
raise e
|
464
431
|
if not req.finished():
|
465
|
-
new_accept_index.extend(new_accept_index_)
|
466
432
|
unfinished_index.append(i)
|
433
|
+
if idx == -1:
|
434
|
+
unfinished_accept_index.append(accept_index[i, :j])
|
435
|
+
else:
|
436
|
+
unfinished_accept_index.append(accept_index[i])
|
467
437
|
req.spec_verify_ct += 1
|
468
438
|
|
469
439
|
if has_finished:
|
470
440
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
471
441
|
|
472
442
|
# Free the KV cache for unaccepted tokens
|
443
|
+
# TODO: fuse them
|
473
444
|
accept_index = accept_index[accept_index != -1]
|
474
445
|
verified_id = predict[accept_index]
|
475
446
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
476
447
|
evict_mask[accept_index] = False
|
477
448
|
|
478
|
-
if page_size
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
449
|
+
if page_size == 1:
|
450
|
+
# TODO: boolean array index leads to a device sync. Remove it.
|
451
|
+
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
452
|
+
else:
|
453
|
+
if self.topk == 1:
|
454
|
+
# Only evict full empty page. Do not evict partial empty page
|
455
|
+
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
456
|
+
batch.seq_lens,
|
457
|
+
evict_mask,
|
458
|
+
page_size,
|
459
|
+
self.draft_token_num,
|
460
|
+
next_power_of_2(self.draft_token_num),
|
461
|
+
)
|
462
|
+
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
463
|
+
else:
|
464
|
+
# Shift the accepted tokens to the beginning.
|
465
|
+
# Only evict the last part
|
466
|
+
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
|
467
|
+
batch.seq_lens,
|
468
|
+
batch.out_cache_loc,
|
469
|
+
accept_index,
|
470
|
+
accept_length,
|
471
|
+
self.draft_token_num,
|
472
|
+
page_size,
|
473
|
+
)
|
474
|
+
to_free_slots = torch.empty(
|
475
|
+
(to_free_num_slots.sum().item(),),
|
476
|
+
dtype=torch.int64,
|
477
|
+
device=to_free_num_slots.device,
|
478
|
+
)
|
486
479
|
|
487
|
-
|
480
|
+
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
|
481
|
+
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
|
482
|
+
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
|
483
|
+
# to_free_slots: [ 2, 5, 7 8]
|
484
|
+
# to_free_slots also needs to be page-aligned without the first partial page
|
485
|
+
#
|
486
|
+
# split each row of out_cache_loc into two parts.
|
487
|
+
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
|
488
|
+
# 2. the second part goes to to_free_slots.
|
489
|
+
get_target_cache_loc[(bs,)](
|
490
|
+
tgt_cache_loc,
|
491
|
+
to_free_slots,
|
492
|
+
accept_length,
|
493
|
+
to_free_num_slots,
|
494
|
+
batch.out_cache_loc,
|
495
|
+
self.draft_token_num,
|
496
|
+
next_power_of_2(self.draft_token_num),
|
497
|
+
next_power_of_2(bs),
|
498
|
+
)
|
499
|
+
|
500
|
+
# Free the kv cache
|
501
|
+
token_to_kv_pool_allocator.free(to_free_slots)
|
502
|
+
|
503
|
+
# Copy the kv cache
|
504
|
+
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
505
|
+
tgt_cache_loc, src_cache_loc
|
506
|
+
)
|
488
507
|
|
489
508
|
# Construct EagleVerifyOutput
|
490
509
|
if not has_finished:
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
510
|
+
if page_size == 1 or self.topk == 1:
|
511
|
+
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
512
|
+
assign_req_to_token_pool[(bs,)](
|
513
|
+
batch.req_pool_indices,
|
514
|
+
batch.req_to_token_pool.req_to_token,
|
515
|
+
batch.seq_lens,
|
516
|
+
batch.seq_lens + accept_length + 1,
|
517
|
+
batch.out_cache_loc,
|
518
|
+
batch.req_to_token_pool.req_to_token.shape[1],
|
519
|
+
next_power_of_2(bs),
|
520
|
+
)
|
521
|
+
else:
|
522
|
+
batch.out_cache_loc = tgt_cache_loc
|
501
523
|
batch.seq_lens.add_(accept_length + 1)
|
502
|
-
accept_length_cpu = accept_length.tolist()
|
503
524
|
|
504
525
|
draft_input = EagleDraftInput()
|
505
526
|
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
506
527
|
draft_input.verified_id = verified_id
|
507
528
|
draft_input.accept_length = accept_length
|
508
|
-
draft_input.accept_length_cpu =
|
529
|
+
draft_input.accept_length_cpu = accept_length.tolist()
|
509
530
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
510
531
|
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
511
532
|
|
@@ -513,47 +534,66 @@ class EagleVerifyInput:
                 draft_input=draft_input,
                 logits_output=logits_output,
                 verified_id=verified_id,
-                accept_length_per_req_cpu=accept_length_cpu,
+                accept_length_per_req_cpu=draft_input.accept_length_cpu,
                 accepted_indices=accept_index,
             )
         else:
-
-
-
-
-
-
-
-
-
-
-
+            if page_size == 1 or self.topk == 1:
+                assign_req_to_token_pool[(bs,)](
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    batch.seq_lens,
+                    batch.seq_lens + accept_length + 1,
+                    batch.out_cache_loc[accept_index],
+                    batch.req_to_token_pool.req_to_token.shape[1],
+                    next_power_of_2(bs),
+                )
+                batch.seq_lens.add_(accept_length + 1)

+            accept_length_cpu = accept_length.tolist()
             draft_input = EagleDraftInput()
-            if len(
-
-            unfinished_index_device = torch.tensor(
-
-
-
-            draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length_cpu = [
+            if len(unfinished_accept_index) > 0:
+                unfinished_accept_index = torch.cat(unfinished_accept_index)
+                unfinished_index_device = torch.tensor(
+                    unfinished_index, dtype=torch.int64, device=predict.device
+                )
+                draft_input_accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
-
-
-            draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                unfinished_index_device
-            ]
-            draft_input.req_pool_indices_for_draft_extend = (
-                batch.req_pool_indices[unfinished_index_device]
-            )
+                if page_size == 1 or self.topk == 1:
+                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
                 else:
-
-
-
+                    batch.out_cache_loc = torch.empty(
+                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
+                        dtype=torch.int64,
+                        device=predict.device,
+                    )
+                    accept_length_filter = create_accept_length_filter(
+                        accept_length,
+                        unfinished_index_device,
+                        batch.seq_lens,
                     )
-
+                    filter_finished_cache_loc_kernel[(bs,)](
+                        batch.out_cache_loc,
+                        tgt_cache_loc,
+                        accept_length,
+                        accept_length_filter,
+                        next_power_of_2(bs),
+                        next_power_of_2(self.draft_token_num),
+                    )
+
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    unfinished_accept_index
+                ]
+                draft_input.verified_id = predict[unfinished_accept_index]
+                draft_input.accept_length_cpu = draft_input_accept_length_cpu
+                draft_input.accept_length = accept_length[unfinished_index_device]
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                    unfinished_index_device
+                ]
+                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                    unfinished_index_device
+                ]

         return EagleVerifyOutput(
             draft_input=draft_input,
@@ -565,26 +605,28 @@ class EagleVerifyInput:


 @triton.jit
-def
+def create_extend_after_decode_spec_info(
     verified_id,
-
-
-    accept_len_cum,
+    seq_lens,
+    accept_lens,
     positions,
     new_verified_id,
-
+    bs_upper: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-
-    seq_length = tl.load(
-    accept_length = tl.load(
-
-
-
-
-
-
-
+    offsets = tl.arange(0, bs_upper)
+    seq_length = tl.load(seq_lens + pid)
+    accept_length = tl.load(accept_lens + pid)
+
+    accept_len_cumsum = tl.sum(
+        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
+    )
+    positions_ptr = positions + accept_len_cumsum
+    mask = offsets < accept_length
+    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
+
+    accept_len_cumsum += accept_length - 1
+    verified_id_data = tl.load(verified_id + accept_len_cumsum)
     tl.store(new_verified_id + pid, verified_id_data)


@@ -605,8 +647,8 @@ def assign_req_to_token_pool(
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

     length_offset = tl.arange(0, bs_upper)
-    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
-    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
     out_offset = tl.sum(end - start, axis=0)

     out_cache_ptr = out_cache_loc + out_offset
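Note (editorial, not part of the diff): several hunks in this file add `other=0` to masked `tl.load` calls whose results feed `tl.sum`; without a default, the masked-out lanes hold undefined values and can corrupt the reduction. A standalone Triton sketch of the pattern (illustrative names, requires a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_sum(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # other=0 makes the lanes at or beyond `n` contribute nothing to the sum;
    # without it, those lanes hold undefined values.
    vals = tl.load(x_ptr + offs, mask=offs < n, other=0)
    tl.store(out_ptr, tl.sum(vals, axis=0))


x = torch.arange(1, 9, device="cuda", dtype=torch.int32)
out = torch.zeros(1, device="cuda", dtype=torch.int32)
masked_sum[(1,)](x, out, 5, BLOCK=8)
print(out.item())  # 1 + 2 + 3 + 4 + 5 = 15
```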
@@ -628,36 +670,75 @@ def assign_draft_cache_locs(
     req_pool_indices,
     req_to_token,
     seq_lens,
+    extend_lens,
+    num_new_pages_per_topk,
     out_cache_loc,
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
     page_size: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr =
+    BLOCK_SIZE: tl.constexpr = 128
     pid = tl.program_id(axis=0)
-    kv_start = tl.load(seq_lens + pid)

     if page_size == 1 or topk == 1:
-
+        copy_len = topk * speculative_num_steps
         out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
     else:
-
-
-
-
-        ) // page_size
-        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+        bs_offset = tl.arange(0, bs_upper)
+        copy_len = tl.load(extend_lens + pid)
+        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
+        out_cache_ptr = out_cache_loc + cum_copy_len

+    # Part 1: Copy from out_cache_loc to req_to_token
+    kv_start = tl.load(seq_lens + pid)
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-    num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
     for i in range(num_loop):
-
-
-
-
-
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = copy_offset < copy_len
+        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
+        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
+
+    if page_size == 1 or topk == 1:
+        return
+
+    # Part 2: Copy the indices for the last partial page
+    prefix_len = tl.load(seq_lens + pid)
+    last_page_len = prefix_len % page_size
+    offsets = tl.arange(0, page_size)
+    mask = offsets < last_page_len
+    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
+    prefix_base = token_pool + prefix_len - last_page_len
+
+    for topk_id in range(topk):
+        value = tl.load(prefix_base + offsets, mask=mask)
+        tl.store(
+            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
+            value,
+            mask=mask,
+        )
+
+    # Part 3: Remove the padding in out_cache_loc
+    iter_offest = tl.arange(0, iter_upper)
+    for topk_id in range(topk):
+        indices = tl.load(
+            prefix_base
+            + topk_id * num_new_pages_per_topk_ * page_size
+            + last_page_len
+            + iter_offest,
+            mask=iter_offest < speculative_num_steps,
+        )
+        tl.store(
+            out_cache_loc
+            + pid * topk * speculative_num_steps
+            + topk_id * speculative_num_steps
+            + iter_offest,
+            indices,
+            mask=iter_offest < speculative_num_steps,
+        )


 @triton.jit
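Note (editorial, not part of the diff): with page_size > 1 and topk > 1, assign_draft_cache_locs now gives each top-k branch its own run of new pages and duplicates the prefix's last partial page at the head of each run. A plain-Python sketch of that layout arithmetic with made-up sizes, mirroring the expressions in the kernel:

```python
page_size = 4
speculative_num_steps = 3
topk = 2
prefix_len = 10  # tokens already in the KV cache for this request

last_page_len = prefix_len % page_size                      # 2 tokens on the partial page
num_new_pages_per_topk = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size                                              # 2 pages per branch
prefix_base = prefix_len // page_size * page_size           # 8: start of the partial page

for topk_id in range(topk):
    branch_start = prefix_base + topk_id * num_new_pages_per_topk * page_size
    # Slots [branch_start, branch_start + last_page_len) hold the duplicated
    # partial page; this branch's draft tokens start right after it.
    draft_start = branch_start + last_page_len
    print(topk_id, branch_start, draft_start)
# 0 8 10
# 1 16 18
```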
@@ -668,29 +749,33 @@ def generate_draft_decode_kv_indices(
     kv_indices,
     kv_indptr,
     positions,
-    num_seqs: tl.constexpr,
-    topk: tl.constexpr,
     pool_len: tl.constexpr,
     kv_indices_stride: tl.constexpr,
     kv_indptr_stride: tl.constexpr,
     bs_upper: tl.constexpr,
     iter_upper: tl.constexpr,
     num_tokens_upper: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 128
     iters = tl.program_id(axis=0)
     bid = tl.program_id(axis=1)
     topk_id = tl.program_id(axis=2)

+    num_steps = tl.num_programs(axis=0)
+    num_seqs = tl.num_programs(axis=1)
+    topk = tl.num_programs(axis=2)
+
     kv_indices += kv_indices_stride * iters
     kv_indptr += kv_indptr_stride * iters
     iters += 1

     load_offset = tl.arange(0, bs_upper)
-    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
     seq_len = tl.load(paged_kernel_lens + bid)
     cum_seq_len = tl.sum(seq_lens)

+    # Update kv_indices
     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
     kv_ptr = kv_indices + kv_offset
     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
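Note (editorial, not part of the diff): the kernel now reads num_steps, num_seqs, and topk from the launch grid via tl.num_programs instead of taking them as constexpr arguments, so the grid stays the single source of truth. A standalone Triton sketch of the idiom (illustrative names, requires a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def grid_dims_demo(out_ptr):
    # Every program can see the launch grid, so sizes implied by the grid
    # do not need to be passed in again; all programs write the same values.
    tl.store(out_ptr + 0, tl.num_programs(axis=0))
    tl.store(out_ptr + 1, tl.num_programs(axis=1))
    tl.store(out_ptr + 2, tl.num_programs(axis=2))


out = torch.zeros(3, device="cuda", dtype=torch.int32)
grid_dims_demo[(4, 2, 8)](out)
print(out.tolist())  # [4, 2, 8]
```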
@@ -704,10 +789,26 @@ def generate_draft_decode_kv_indices(
         kv_offset += BLOCK_SIZE

     extend_offset = tl.arange(0, iter_upper)
-
-
-
-
+    if page_size == 1 or topk == 1:
+        extend_data = tl.load(
+            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
+            mask=extend_offset < iters,
+        )
+    else:
+        prefix_len = seq_len
+        last_page_len = prefix_len % page_size
+        num_new_pages_per_topk = (
+            last_page_len + num_steps + page_size - 1
+        ) // page_size
+        prefix_base = seq_len // page_size * page_size
+        start = (
+            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
+        )
+        extend_data = tl.load(
+            token_pool_ptr + start + extend_offset,
+            mask=extend_offset < iters,
+        )
+
     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

     # Update kv_indptr
@@ -716,7 +817,7 @@ def generate_draft_decode_kv_indices(
     zid = bid * topk + topk_id
     if zid == 0:
         zid = num_seqs * topk
-    positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
+    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
     base = tl.sum(positions)
     tl.store(kv_indptr + zid, base + zid * iters)

@@ -734,7 +835,9 @@ def align_evict_mask_to_page_size(
     bid = tl.program_id(axis=0)
     seq_len = tl.load(seq_lens + bid)
     io_mask = t_range < num_draft_tokens
-    mask_row = tl.load(
+    mask_row = tl.load(
+        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
+    )

     num_trues = tl.sum(mask_row)
     num_false = num_draft_tokens - num_trues
@@ -744,6 +847,116 @@ def align_evict_mask_to_page_size(
             tl.store(evict_mask + bid * num_draft_tokens + i, False)


+@triton.jit
+def get_target_cache_loc(
+    tgt_cache_loc,
+    to_free_slots,
+    accept_length,
+    to_free_num_slots,
+    out_cache_loc,
+    num_verify_tokens: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    bid = tl.program_id(axis=0)
+    offset = tl.arange(0, num_verify_tokens_upper)
+    bs_offset = tl.arange(0, bs_upper)
+
+    # write the first part to tgt_cache_loc
+    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
+    copy_len = tl.load(accept_length + bid) + 1
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
+    )
+    tl.store(
+        tgt_cache_loc + tgt_cache_loc_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+    # write the second part to to_free_num_pages
+    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
+    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
+    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
+    to_free_slots_start = tl.sum(to_free_num_slots_all)
+
+    copy_len = to_free_num_slots_cur
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
+        mask=offset < copy_len,
+    )
+    tl.store(
+        to_free_slots + to_free_slots_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+
+@torch.compile(dynamic=True)
+def get_src_tgt_cache_loc(
+    seq_lens: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_length: torch.Tensor,
+    draft_token_num: int,
+    page_size: int,
+):
+    src_cache_loc = out_cache_loc[accept_index]
+    tgt_cache_loc = torch.empty_like(src_cache_loc)
+    extended_len = seq_lens + draft_token_num
+    keep_len = torch.minimum(
+        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
+        extended_len,
+    )
+    to_free_num_slots = extended_len - keep_len
+    return src_cache_loc, tgt_cache_loc, to_free_num_slots
+
+
+@triton.jit
+def filter_finished_cache_loc_kernel(
+    out_cache_loc,
+    tgt_cache_loc,
+    accept_length,
+    accept_length_filter,
+    bs_upper: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+):
+    bid = tl.program_id(0)
+    bs_offset = tl.arange(0, bs_upper)
+
+    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    old_start = tl.sum(accept_length_all) + bid
+
+    accept_length_filter_all = tl.load(
+        accept_length_filter + bs_offset, mask=bs_offset < bid
+    )
+    new_start = tl.sum(accept_length_filter_all)
+
+    copy_len = tl.load(accept_length_filter + bid)
+    copy_offset = tl.arange(0, num_verify_tokens_upper)
+    value = tl.load(
+        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
+    )
+    tl.store(
+        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
+    )
+
+
+@torch.compile(dynamic=True)
+def create_accept_length_filter(
+    accept_length: torch.Tensor,
+    unfinished_index_device: torch.Tensor,
+    seq_lens: torch.Tensor,
+):
+    accept_length_filter = torch.zeros_like(accept_length)
+    accept_length_filter[unfinished_index_device] = (
+        accept_length[unfinished_index_device] + 1
+    )
+    seq_lens.add_(accept_length + 1)
+    return accept_length_filter
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
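Note (editorial, not part of the diff): get_src_tgt_cache_loc (added above) keeps the accepted tokens plus enough extra slots to reach the next page boundary and marks the remainder for freeing. A quick numeric check of that arithmetic with made-up values:

```python
import torch

page_size = 4
draft_token_num = 8
seq_lens = torch.tensor([10])      # tokens already cached for this request
accept_length = torch.tensor([2])  # accepted draft tokens (the verified token adds 1)

extended_len = seq_lens + draft_token_num                 # 18 slots were allocated
keep_len = torch.minimum(
    (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
    extended_len,
)                                                         # ceil(13 / 4) * 4 = 16
to_free_num_slots = extended_len - keep_len               # 2 slots get evicted
print(keep_len.item(), to_free_num_slots.item())          # 16 2
```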
@@ -802,15 +1015,35 @@ def _generate_simulated_accept_index(
     spec_steps,
 ):
     simulate_acc_len_float = float(simulate_acc_len)
-
-
-
-
-
-
-
-
-
+    if SIMULATE_ACC_METHOD == "multinomial":
+        simulated_values = torch.normal(
+            mean=simulate_acc_len_float,
+            std=1.0,
+            size=(1,),
+            device="cpu",
+        )
+        # clamp simulated values to be between 1 and self.spec_steps
+        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
+        simulate_acc_len = int(simulated_values.round().item())
+    elif SIMULATE_ACC_METHOD == "match-expected":
+        # multinomial sampling does not match the expected length
+        # we keep it for the sake of compatibility of existing tests
+        # but it's better to use "match-expected" for the cases that need to
+        # match the expected length, One caveat is that this will only sample
+        # either round down or round up of the expected length
+        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+        lower = int(simulate_acc_len_float // 1)
+        upper = lower + 1 if lower < spec_steps + 1 else lower
+        if lower == upper:
+            simulate_acc_len = lower
+        else:
+            weight_upper = simulate_acc_len_float - lower
+            weight_lower = 1.0 - weight_upper
+            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
+            sampled_index = torch.multinomial(probs, num_samples=1)
+            simulate_acc_len = lower if sampled_index == 0 else upper
+    else:
+        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")

     accept_indx_first_col = accept_index[:, 0].view(-1, 1)
     sim_accept_index = torch.full(
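Note (editorial, not part of the diff): the new "match-expected" mode samples only the floor or ceiling of the target length, weighted so the mean equals the target. A standalone sanity check of that scheme with a made-up target:

```python
import torch

target = 2.3
lower = int(target)
upper = lower + 1
weight_upper = target - lower                      # 0.3
probs = torch.tensor([1.0 - weight_upper, weight_upper])

samples = torch.multinomial(probs, num_samples=100_000, replacement=True)
drawn = (lower + samples).float()                  # index 0 -> lower, index 1 -> upper
print(drawn.mean().item())                         # ~2.3: expectation matches the target
```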
@@ -901,9 +1134,9 @@ def generate_token_bitmask(
     """
     Generate the logit mask for structured output.
     Draft model's token can be either valid or invalid with respect to the grammar.
-    We need to perform DFS to
-    1. which tokens are accepted by the grammar
-    2. what is the corresponding logit mask.
+    We need to perform DFS to
+    1. figure out which tokens are accepted by the grammar.
+    2. if so, what is the corresponding logit mask.
     """

     num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -920,6 +1153,7 @@ def generate_token_bitmask(
             device="cpu",
         )
         grammar = req.grammar
+        s = time.perf_counter()
         traverse_tree(
             retrieve_next_token_cpu[i],
             retrieve_next_sibling_cpu[i],
@@ -929,6 +1163,12 @@ def generate_token_bitmask(
                 i * num_draft_tokens : (i + 1) * num_draft_tokens
             ],
         )
+        tree_traverse_time = time.perf_counter() - s
+        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+            logger.warning(
+                f"Bit mask generation took {tree_traverse_time} seconds with "
+                f"grammar: {req.grammar}"
+            )

     verify_input.grammar = grammar
     return allocate_token_bitmask