sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,11 @@ from typing import Optional, Union
|
|
3
3
|
import torch
|
4
4
|
|
5
5
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
6
|
+
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
6
7
|
from sglang.srt.layers.radix_attention import RadixAttention
|
7
8
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
8
9
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
9
|
-
from sglang.srt.speculative.
|
10
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
10
11
|
|
11
12
|
|
12
13
|
class HybridAttnBackend(AttentionBackend):
|
@@ -21,6 +22,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
21
22
|
self.model_runner = model_runner
|
22
23
|
self.prefill_backend = prefill_backend
|
23
24
|
self.decode_backend = decode_backend
|
25
|
+
self.data_type = model_runner.kv_cache_dtype
|
24
26
|
|
25
27
|
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
26
28
|
"""
|
@@ -70,7 +72,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
70
72
|
seq_lens: torch.Tensor,
|
71
73
|
encoder_lens: Optional[torch.Tensor],
|
72
74
|
forward_mode: ForwardMode,
|
73
|
-
spec_info: Optional[
|
75
|
+
spec_info: Optional[SpecInput],
|
74
76
|
):
|
75
77
|
backend = self._select_backend(forward_mode)
|
76
78
|
backend.init_forward_metadata_capture_cuda_graph(
|
@@ -91,7 +93,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
91
93
|
seq_lens_sum: int,
|
92
94
|
encoder_lens: Optional[torch.Tensor],
|
93
95
|
forward_mode: ForwardMode,
|
94
|
-
spec_info: Optional[
|
96
|
+
spec_info: Optional[SpecInput],
|
95
97
|
seq_lens_cpu: Optional[torch.Tensor],
|
96
98
|
):
|
97
99
|
backend = self._select_backend(forward_mode)
|
@@ -137,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
|
|
137
139
|
return backend.forward_extend(
|
138
140
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
139
141
|
)
|
142
|
+
|
143
|
+
def get_indexer_metadata(
|
144
|
+
self, layer_id: int, forward_batch: ForwardBatch
|
145
|
+
) -> Optional[BaseIndexerMetadata]:
|
146
|
+
backend = self._select_backend(forward_batch.forward_mode)
|
147
|
+
return backend.get_indexer_metadata(layer_id, forward_batch)
|
@@ -21,11 +21,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|
21
21
|
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
22
22
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
23
23
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
24
|
-
from sglang.srt.models.qwen3_next import
|
25
|
-
from sglang.srt.speculative.
|
26
|
-
from sglang.srt.utils import is_npu
|
24
|
+
from sglang.srt.models.qwen3_next import fused_gdn_gating
|
25
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
26
|
+
from sglang.srt.utils import is_cuda, is_npu
|
27
27
|
|
28
|
-
if
|
28
|
+
if is_cuda():
|
29
|
+
from sglang.srt.layers.attention.mamba.causal_conv1d import (
|
30
|
+
causal_conv1d_fn as causal_conv1d_fn_cuda,
|
31
|
+
)
|
32
|
+
|
33
|
+
causal_conv1d_fn = causal_conv1d_fn_cuda
|
34
|
+
elif is_npu():
|
29
35
|
from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
|
30
36
|
from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
|
31
37
|
fused_sigmoid_gating_delta_rule_update_npu,
|
@@ -58,18 +64,16 @@ class MambaAttnBackend(AttentionBackend):
|
|
58
64
|
self.forward_metadata: ForwardMetadata = None
|
59
65
|
self.state_indices_list = []
|
60
66
|
self.query_start_loc_list = []
|
61
|
-
|
62
|
-
|
63
|
-
@lru_cache(maxsize=128)
|
64
|
-
def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
|
65
|
-
"""Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
|
66
|
-
device = torch.device(device_str)
|
67
|
-
return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
|
67
|
+
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
|
68
|
+
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
|
68
69
|
|
69
70
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
70
71
|
bs = forward_batch.batch_size
|
72
|
+
|
71
73
|
if forward_batch.forward_mode.is_decode_or_idle():
|
72
|
-
query_start_loc =
|
74
|
+
query_start_loc = torch.arange(
|
75
|
+
0, bs + 1, dtype=torch.int32, device=self.device
|
76
|
+
)
|
73
77
|
elif forward_batch.forward_mode.is_extend():
|
74
78
|
if forward_batch.forward_mode.is_target_verify():
|
75
79
|
query_start_loc = torch.arange(
|
@@ -99,6 +103,10 @@ class MambaAttnBackend(AttentionBackend):
|
|
99
103
|
)
|
100
104
|
|
101
105
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
106
|
+
assert (
|
107
|
+
max_num_tokens % max_bs == 0
|
108
|
+
), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
|
109
|
+
verify_step = max_num_tokens / max_bs
|
102
110
|
for i in range(max_bs):
|
103
111
|
self.state_indices_list.append(
|
104
112
|
torch.full(
|
@@ -108,6 +116,16 @@ class MambaAttnBackend(AttentionBackend):
|
|
108
116
|
self.query_start_loc_list.append(
|
109
117
|
torch.empty((i + 2,), dtype=torch.int32, device=self.device)
|
110
118
|
)
|
119
|
+
self.cached_cuda_graph_decode_query_start_loc = torch.arange(
|
120
|
+
0, max_bs + 1, dtype=torch.int32, device=self.device
|
121
|
+
)
|
122
|
+
self.cached_cuda_graph_verify_query_start_loc = torch.arange(
|
123
|
+
0,
|
124
|
+
max_bs * verify_step + 1,
|
125
|
+
step=verify_step,
|
126
|
+
dtype=torch.int32,
|
127
|
+
device=self.device,
|
128
|
+
)
|
111
129
|
|
112
130
|
def init_forward_metadata_capture_cuda_graph(
|
113
131
|
self,
|
@@ -117,19 +135,15 @@ class MambaAttnBackend(AttentionBackend):
|
|
117
135
|
seq_lens: torch.Tensor,
|
118
136
|
encoder_lens: Optional[torch.Tensor],
|
119
137
|
forward_mode: ForwardMode,
|
120
|
-
spec_info: Optional[
|
138
|
+
spec_info: Optional[SpecInput],
|
121
139
|
):
|
122
140
|
if forward_mode.is_decode_or_idle():
|
123
|
-
self.query_start_loc_list[bs - 1].copy_(
|
141
|
+
self.query_start_loc_list[bs - 1].copy_(
|
142
|
+
self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
|
143
|
+
)
|
124
144
|
elif forward_mode.is_target_verify():
|
125
145
|
self.query_start_loc_list[bs - 1].copy_(
|
126
|
-
|
127
|
-
0,
|
128
|
-
bs * spec_info.draft_token_num + 1,
|
129
|
-
step=spec_info.draft_token_num,
|
130
|
-
dtype=torch.int32,
|
131
|
-
device=self.device,
|
132
|
-
)
|
146
|
+
self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
|
133
147
|
)
|
134
148
|
else:
|
135
149
|
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
@@ -148,7 +162,7 @@ class MambaAttnBackend(AttentionBackend):
|
|
148
162
|
seq_lens_sum: int,
|
149
163
|
encoder_lens: Optional[torch.Tensor],
|
150
164
|
forward_mode: ForwardMode,
|
151
|
-
spec_info: Optional[
|
165
|
+
spec_info: Optional[SpecInput],
|
152
166
|
seq_lens_cpu: Optional[torch.Tensor],
|
153
167
|
):
|
154
168
|
num_padding = torch.count_nonzero(
|
@@ -160,23 +174,29 @@ class MambaAttnBackend(AttentionBackend):
|
|
160
174
|
mamba_indices[bs - num_padding :] = -1
|
161
175
|
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
|
162
176
|
if forward_mode.is_decode_or_idle():
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
elif forward_mode.is_target_verify():
|
167
|
-
self.query_start_loc_list[bs - 1].copy_(
|
168
|
-
torch.arange(
|
169
|
-
0,
|
170
|
-
bs * spec_info.draft_token_num + 1,
|
171
|
-
step=spec_info.draft_token_num,
|
172
|
-
dtype=torch.int32,
|
173
|
-
device=self.device,
|
177
|
+
if num_padding == 0:
|
178
|
+
self.query_start_loc_list[bs - 1].copy_(
|
179
|
+
self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
|
174
180
|
)
|
175
|
-
|
176
|
-
|
177
|
-
|
181
|
+
else:
|
182
|
+
self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
|
183
|
+
self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
|
184
|
+
)
|
185
|
+
self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
|
178
186
|
bs - num_padding
|
179
|
-
)
|
187
|
+
)
|
188
|
+
elif forward_mode.is_target_verify():
|
189
|
+
if num_padding == 0:
|
190
|
+
self.query_start_loc_list[bs - 1].copy_(
|
191
|
+
self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
|
192
|
+
)
|
193
|
+
else:
|
194
|
+
self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
|
195
|
+
self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
|
196
|
+
)
|
197
|
+
self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
|
198
|
+
(bs - num_padding) * spec_info.draft_token_num
|
199
|
+
)
|
180
200
|
else:
|
181
201
|
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
182
202
|
|
@@ -343,6 +363,7 @@ class MambaAttnBackend(AttentionBackend):
|
|
343
363
|
has_initial_state=has_initial_states,
|
344
364
|
cache_indices=cache_indices,
|
345
365
|
query_start_loc=query_start_loc,
|
366
|
+
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
346
367
|
).transpose(0, 1)[:seq_len]
|
347
368
|
|
348
369
|
key_split_dim = key_dim // attn_tp_size
|
@@ -431,7 +452,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
431
452
|
seq_lens: torch.Tensor,
|
432
453
|
encoder_lens: Optional[torch.Tensor],
|
433
454
|
forward_mode: ForwardMode,
|
434
|
-
spec_info: Optional[
|
455
|
+
spec_info: Optional[SpecInput],
|
435
456
|
):
|
436
457
|
for attn_backend in self.attn_backend_list:
|
437
458
|
attn_backend.init_forward_metadata_capture_cuda_graph(
|
@@ -452,7 +473,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
452
473
|
seq_lens_sum: int,
|
453
474
|
encoder_lens: Optional[torch.Tensor],
|
454
475
|
forward_mode: ForwardMode,
|
455
|
-
spec_info: Optional[
|
476
|
+
spec_info: Optional[SpecInput],
|
456
477
|
seq_lens_cpu: Optional[torch.Tensor],
|
457
478
|
):
|
458
479
|
for attn_backend in self.attn_backend_list:
|
@@ -567,36 +588,15 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
567
588
|
|
568
589
|
# Compute common indices once to avoid duplication
|
569
590
|
last_steps_all = (accepted_length - 1).to(torch.int64)
|
570
|
-
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
|
571
|
-
last_steps = last_steps_all[valid_mask].to(torch.int64)
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
for j in range(idx.numel()):
|
583
|
-
ci = idx[j].item()
|
584
|
-
st = steps[j].item()
|
585
|
-
ssm_states[:, ci, :].copy_(
|
586
|
-
intermediate_state_cache[:, ci, st].to(
|
587
|
-
ssm_states.dtype, copy=False
|
588
|
-
)
|
589
|
-
)
|
590
|
-
|
591
|
-
# Conv window updates
|
592
|
-
for i in range(0, num_valid, chunk):
|
593
|
-
idx = valid_state_indices[i : i + chunk]
|
594
|
-
steps = last_steps[i : i + chunk]
|
595
|
-
for j in range(idx.numel()):
|
596
|
-
ci = idx[j].item()
|
597
|
-
st = steps[j].item()
|
598
|
-
conv_states[:, ci, :, :].copy_(
|
599
|
-
intermediate_conv_window_cache[:, ci, st].to(
|
600
|
-
conv_states.dtype, copy=False
|
601
|
-
)
|
602
|
-
)
|
591
|
+
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
|
592
|
+
last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
|
593
|
+
|
594
|
+
# scatter into ssm_states at the chosen cache lines
|
595
|
+
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
|
596
|
+
:, valid_state_indices, last_steps
|
597
|
+
].to(ssm_states.dtype, copy=False)
|
598
|
+
|
599
|
+
# Scatter into conv_states at the chosen cache lines
|
600
|
+
conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
|
601
|
+
:, valid_state_indices, last_steps
|
602
|
+
].to(conv_states.dtype, copy=False)
|
@@ -2,7 +2,7 @@
|
|
2
2
|
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
3
3
|
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
4
4
|
|
5
|
-
from typing import Optional, Union
|
5
|
+
from typing import List, Optional, Union
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import torch
|
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|
22
22
|
cache_indices_ptr, # conv_state_indices_ptr
|
23
23
|
has_initial_states_ptr,
|
24
24
|
query_start_loc_ptr,
|
25
|
-
batch_ptr,
|
26
|
-
token_chunk_offset_ptr,
|
27
25
|
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
28
26
|
# Matrix dimensions
|
29
|
-
batch: tl.int32, # actually padded_batch
|
30
27
|
dim: tl.constexpr,
|
31
28
|
seqlen: tl.int32, # cu_seqlen
|
32
29
|
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|
69
66
|
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
70
67
|
|
71
68
|
# single-sequence id
|
72
|
-
idx_seq = tl.
|
73
|
-
chunk_offset = tl.
|
69
|
+
idx_seq = tl.program_id(0)
|
70
|
+
chunk_offset = tl.program_id(1)
|
74
71
|
|
75
72
|
# BLOCK_N elements along the feature-dimension (channel)
|
76
|
-
idx_feats = tl.program_id(
|
73
|
+
idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
|
77
74
|
|
78
75
|
if idx_seq == pad_slot_id:
|
79
76
|
return
|
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|
86
83
|
token_offset = BLOCK_M * chunk_offset
|
87
84
|
segment_len = min(BLOCK_M, seqlen - token_offset)
|
88
85
|
|
86
|
+
if segment_len <= 0:
|
87
|
+
return
|
88
|
+
|
89
89
|
# base of the sequence
|
90
90
|
x_base = (
|
91
91
|
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
|
@@ -382,12 +382,13 @@ def causal_conv1d_fn(
|
|
382
382
|
bias: Union[torch.Tensor, None],
|
383
383
|
conv_states: torch.Tensor,
|
384
384
|
query_start_loc: torch.Tensor,
|
385
|
+
seq_lens_cpu: List[int],
|
385
386
|
cache_indices: Optional[torch.Tensor] = None,
|
386
387
|
has_initial_state: Optional[torch.Tensor] = None,
|
387
388
|
activation: Optional[str] = "silu",
|
388
389
|
pad_slot_id: int = PAD_SLOT_ID,
|
389
|
-
metadata=None,
|
390
390
|
validate_data=False,
|
391
|
+
**kwargs,
|
391
392
|
):
|
392
393
|
"""support varlen + continuous batching when x is 2D tensor
|
393
394
|
|
@@ -413,6 +414,8 @@ def causal_conv1d_fn(
|
|
413
414
|
[length(query_start_loc)-1 == batch]
|
414
415
|
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
415
416
|
x.shape=(dim,17)
|
417
|
+
seq_lens_cpu: (batch) int32
|
418
|
+
The sequence lengths of the sequences in the batch
|
416
419
|
cache_indices: (batch) int32
|
417
420
|
indicates the corresponding state index,
|
418
421
|
like so: conv_state = conv_states[cache_indices[batch_id]]
|
@@ -434,26 +437,7 @@ def causal_conv1d_fn(
|
|
434
437
|
if isinstance(activation, bool) and activation:
|
435
438
|
activation = "silu"
|
436
439
|
|
437
|
-
args = None
|
438
440
|
out = torch.empty_like(x)
|
439
|
-
if metadata is not None:
|
440
|
-
cu_seqlen = metadata.cu_seqlen
|
441
|
-
nums_dict = metadata.nums_dict
|
442
|
-
# x = metadata.x
|
443
|
-
args = nums_dict
|
444
|
-
batch_ptr = metadata.batch_ptr
|
445
|
-
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
446
|
-
else:
|
447
|
-
seqlens = np.diff(query_start_loc.to("cpu"))
|
448
|
-
args = seqlens
|
449
|
-
MAX_NUM_PROGRAMS = 1024
|
450
|
-
|
451
|
-
batch_ptr = torch.full(
|
452
|
-
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
453
|
-
) # tracking which seq-idx the Triton program is handling
|
454
|
-
token_chunk_offset_ptr = torch.full(
|
455
|
-
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
456
|
-
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
|
457
441
|
|
458
442
|
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
459
443
|
dim, cu_seqlen = x.shape
|
@@ -461,7 +445,6 @@ def causal_conv1d_fn(
|
|
461
445
|
state_len = width - 1
|
462
446
|
np2_statelen = triton.next_power_of_2(state_len)
|
463
447
|
|
464
|
-
padded_batch = query_start_loc.size(0) - 1
|
465
448
|
stride_x_seq = 0
|
466
449
|
stride_x_dim = x.stride(0)
|
467
450
|
stride_x_token = x.stride(1)
|
@@ -501,6 +484,7 @@ def causal_conv1d_fn(
|
|
501
484
|
assert query_start_loc is not None
|
502
485
|
assert query_start_loc.dim() == 1
|
503
486
|
assert x.stride(0) == 1 or x.stride(1) == 1
|
487
|
+
padded_batch = query_start_loc.size(0) - 1
|
504
488
|
if bias is not None:
|
505
489
|
assert bias.dim() == 1
|
506
490
|
assert dim == bias.size(0)
|
@@ -516,78 +500,14 @@ def causal_conv1d_fn(
|
|
516
500
|
assert (dim, width) == weight.shape
|
517
501
|
assert is_channel_last, "Need to run in channel-last layout"
|
518
502
|
|
519
|
-
if metadata is None:
|
520
|
-
|
521
|
-
def num_program(META, seqlens):
|
522
|
-
tot = 0
|
523
|
-
|
524
|
-
mlist = []
|
525
|
-
offsetlist = [] # type: ignore
|
526
|
-
|
527
|
-
nums = -(-seqlens // META["BLOCK_M"])
|
528
|
-
|
529
|
-
tot = nums.sum().item()
|
530
|
-
mlist = np.repeat(np.arange(len(nums)), nums)
|
531
|
-
for idx, num in enumerate(nums):
|
532
|
-
offsetlist.extend(
|
533
|
-
range(num)
|
534
|
-
) # chunk-idx if a sequence is split into multiple chunks
|
535
|
-
|
536
|
-
if META["batch_ptr"].nelement() < len(mlist):
|
537
|
-
newlen = len(mlist) + 1
|
538
|
-
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
539
|
-
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
540
|
-
|
541
|
-
if META["batch_ptr"].nelement() >= len(mlist):
|
542
|
-
META["batch_ptr"][0 : len(mlist)].copy_(
|
543
|
-
torch.from_numpy(np.array(mlist))
|
544
|
-
)
|
545
|
-
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
|
546
|
-
torch.from_numpy(np.array(offsetlist))
|
547
|
-
)
|
548
|
-
|
549
|
-
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
|
550
|
-
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
|
551
|
-
META["x_ptr"].device
|
552
|
-
)
|
553
|
-
return tot
|
554
|
-
|
555
|
-
else:
|
556
|
-
|
557
|
-
def num_program(META, nums_dict):
|
558
|
-
tot = nums_dict[META["BLOCK_M"]]["tot"]
|
559
|
-
|
560
|
-
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
|
561
|
-
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
|
562
|
-
|
563
|
-
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
|
564
|
-
|
565
|
-
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
|
566
|
-
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
|
567
|
-
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
|
568
|
-
"token_chunk_offset_ptr"
|
569
|
-
]
|
570
|
-
else:
|
571
|
-
if META["batch_ptr"].nelement() < mlist_len:
|
572
|
-
newlen = mlist_len + 1
|
573
|
-
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
574
|
-
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
575
|
-
|
576
|
-
if META["batch_ptr"].nelement() >= mlist_len:
|
577
|
-
META["batch_ptr"][0:mlist_len].copy_(mlist)
|
578
|
-
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
|
579
|
-
return tot
|
580
|
-
|
581
503
|
def grid(META):
|
504
|
+
max_seq_len = max(seq_lens_cpu)
|
582
505
|
return (
|
583
|
-
|
506
|
+
len(seq_lens_cpu), # batch_size
|
507
|
+
(max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
|
584
508
|
triton.cdiv(dim, META["BLOCK_N"]),
|
585
509
|
)
|
586
510
|
|
587
|
-
if batch_ptr.device != x.device:
|
588
|
-
batch_ptr = batch_ptr.to(x.device)
|
589
|
-
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
|
590
|
-
|
591
511
|
_causal_conv1d_fwd_kernel[grid](
|
592
512
|
# Pointers to matrices
|
593
513
|
x,
|
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
|
|
597
517
|
cache_indices,
|
598
518
|
has_initial_state,
|
599
519
|
query_start_loc,
|
600
|
-
batch_ptr,
|
601
|
-
token_chunk_offset_ptr,
|
602
520
|
out,
|
603
521
|
# Matrix dimensions
|
604
|
-
padded_batch,
|
605
522
|
dim,
|
606
523
|
cu_seqlen,
|
607
524
|
num_cache_lines,
|