sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/tbo_backend.py

```diff
@@ -1,10 +1,10 @@
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
 from sglang.srt import two_batch_overlap
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         self.primary.init_forward_metadata_capture_cuda_graph(
             bs=bs,
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         self.primary.init_forward_metadata_replay_cuda_graph(
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         # capture args
         capture_num_tokens: int = None,
         # replay args
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
    forward_mode: "ForwardMode",
-    spec_info: Optional[
+    spec_info: Optional[SpecInput],
    # capture args
    capture_num_tokens: int = None,
    # replay args
```

sglang/srt/layers/attention/torch_flex_backend.py (new file)

```diff
@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchFlexAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+        self.flex_attention = torch.compile(flex_attention, dynamic=True)
+        torch._dynamo.config.cache_size_limit = 1024
+        torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        # TODO: find a more elegant way to save memory
+        # Currently maintain the same memory as torch_native_backend
+        torch.cuda.empty_cache()
+
+        # Provide two block_mask Lists per seq_idx for lower latency, later will support per layer level mask generation
+        self.extend_block_masks = []
+        self.decode_block_masks = []
+
+        if forward_batch.forward_mode.is_extend():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+                seq_len_q = seq_len_kv
+                self.extend_block_masks.append(
+                    create_block_mask(
+                        self._causal_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+        elif forward_batch.forward_mode.is_decode():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_q = 1
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+
+                self.decode_block_masks.append(
+                    create_block_mask(
+                        self._decode_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+    def _causal_mask(self, b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    def _decode_mask(self, b, h, q_idx, kv_idx):
+        return q_idx <= kv_idx
+
+    def _run_flex_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward by using torch flex attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redundant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            if not causal:
+                raise NotImplementedError("Non-causal mode is not yet implemented.")
+
+            per_req_out_redundant = (
+                self.flex_attention(
+                    per_req_query_redundant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.extend_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redundant[
+                prefill_seq_len_q:, :, :
+            ]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_flex_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward by using torch flex attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                self.flex_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.decode_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            raise NotImplementedError(
+                "TorchFlexAttnBackend does not support non-causal attention for now."
+            )
+
+        self._run_flex_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=causal,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_flex_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
+
+    def support_triton(self):
+        return False
```
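
The new backend above is essentially a thin wrapper around PyTorch's `flex_attention` API: a `mask_mod` callback is materialized into a `BlockMask` via `create_block_mask` and then passed to `flex_attention`. As a rough, standalone illustration of that pattern (not sglang code; the shapes, the `causal_mask` helper, and the tensor sizes below are invented for the example, and a recent PyTorch with flex attention support is assumed):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # A query position may attend only to itself and earlier KV positions,
    # mirroring the _causal_mask used by the backend above.
    return q_idx >= kv_idx


B, H, L, D = 1, 8, 128, 64
# flex_attention is primarily exercised on CUDA; the CPU fallback is illustrative only.
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(B, H, L, D, device=device)
k = torch.randn(B, H, L, D, device=device)
v = torch.randn(B, H, L, D, device=device)

# Passing None for the batch/head dims broadcasts one mask across all of them.
block_mask = create_block_mask(causal_mask, None, None, L, L, device=device)
out = flex_attention(q, k, v, block_mask=block_mask, scale=D**-0.5)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```

In the backend itself, `flex_attention` is additionally wrapped in `torch.compile(..., dynamic=True)` and one `BlockMask` is built per sequence, since prefill and decode use different query/KV lengths.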

sglang/srt/layers/attention/triton_backend.py

```diff
@@ -12,12 +12,17 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_device_core_count,
+    get_int_env_var,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 def logit_capping_mod(logit_capping_method, logit_cap):
@@ -94,7 +99,25 @@ class TritonAttnBackend(AttentionBackend):
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-
+
+        # Decide whether enable deterministic inference with batch-invariant operations
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+
+        # Configure deterministic inference settings
+        if self.enable_deterministic:
+            # Use fixed split tile size for batch invariance
+            self.split_tile_size = get_int_env_var(
+                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
+            )
+            # Set static_kv_splits to False to use deterministic logic instead
+            self.static_kv_splits = False
+        else:
+            self.split_tile_size = (
+                model_runner.server_args.triton_attention_split_tile_size
+            )
+
         if self.split_tile_size is not None:
             self.max_kv_splits = (
                 self.max_context_len + self.split_tile_size - 1
@@ -154,13 +177,23 @@ class TritonAttnBackend(AttentionBackend):
             num_group * num_seq == num_token
         ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
 
-
+        # Legacy dynamic splitting logic (non-deterministic)
+        if (
+            self.static_kv_splits or self.device_core_count <= 0
+        ) and not self.enable_deterministic:
             num_kv_splits.fill_(self.max_kv_splits)
             return
 
-
+        # deterministic
+        if self.split_tile_size is not None and self.enable_deterministic:
+            # expand seq_lens to match num_token
+            if num_group > 1:
+                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
+            else:
+                expanded_seq_lens = seq_lens
+
             num_kv_splits[:] = (
-
+                expanded_seq_lens + self.split_tile_size - 1
             ) // self.split_tile_size
             return
 
@@ -449,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"
         window_kv_indptr = self.window_kv_indptr
@@ -605,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
@@ -850,7 +883,7 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
```
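
In the deterministic path added above, the number of KV splits is derived purely from each sequence's length and a fixed tile size (256 by default, overridable via `SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE`), i.e. a ceiling division, rather than from device occupancy heuristics, so a request's split count does not change with batch composition. A small illustration of that arithmetic (example values invented here):

```python
import torch

# Fixed tile size assumed for batch-invariant splitting.
split_tile_size = 256
seq_lens = torch.tensor([100, 256, 257, 4096])

# Ceiling division: ceil(seq_len / split_tile_size) per sequence.
num_kv_splits = (seq_lens + split_tile_size - 1) // split_tile_size
print(num_kv_splits.tolist())  # [1, 1, 2, 16]
```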

sglang/srt/layers/attention/trtllm_mha_backend.py

```diff
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     import flashinfer
 
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
 
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
```
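
Across all four attention backends the pattern of these hunks is the same: eagle-specific types in annotations and asserts are replaced by the shared `SpecInput` interface from `sglang.srt.speculative.spec_info`, with the concrete variant checked at runtime via `is_draft_input()`. A hypothetical, simplified sketch of why that decouples backends from the EAGLE implementation (only `SpecInput` and `is_draft_input` come from the diff; every other name below is invented for illustration and is not the actual sglang class hierarchy):

```python
from abc import ABC, abstractmethod


class SpecInput(ABC):
    """Hypothetical common interface for speculative-decoding inputs."""

    @abstractmethod
    def is_draft_input(self) -> bool: ...


class DraftInput(SpecInput):  # invented stand-in for a draft-side input
    def is_draft_input(self) -> bool:
        return True


class VerifyInput(SpecInput):  # invented stand-in for a verify-side input
    def is_draft_input(self) -> bool:
        return False


def init_draft_metadata(spec_info: SpecInput) -> None:
    # A backend can annotate against the shared interface instead of importing
    # eagle-specific classes, and still assert the variant it expects.
    assert spec_info.is_draft_input()


init_draft_metadata(DraftInput())
```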