sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/activation.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()

-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):

     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
         out = torch_npu.npu_swiglu(x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,24 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out

+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available and self.approximate == "tanh":
+            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        elif _is_cpu_amx_available and self.approximate == "none":
+            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
@@ -150,6 +171,115 @@ class QuickGELU(CustomOp):
         return torch_npu.npu_fast_gelu(x)


+class XIELU(CustomOp):
+    """
+    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
+    Otherwise, we emit a single warning and use xIELU Python
+    """
+
+    def __init__(
+        self,
+        alpha_p_init: float = 0.8,
+        alpha_n_init: float = 0.8,
+        beta: float = 0.5,
+        eps: float = -1e-6,
+        dtype: torch.dtype = torch.bfloat16,
+        with_vector_loads: bool = False,
+    ):
+        super().__init__()
+        self.alpha_p = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
+                0
+            )
+        )
+        self.alpha_n = nn.Parameter(
+            torch.log(
+                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
+            ).unsqueeze(0)
+        )
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA fully implemented
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += (
+                    f" Could not enable torch._dynamo for xIELU ({err}) - "
+                    "this may result in slower performance."
+                )
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            logger.warning_once(
+                "CUDA-fused xIELU not available (%s) –"
+                " falling back to a Python version.\n"
+                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+                str(err),
+            )
+
+    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        """Firewall function to prevent torch.compile from seeing .item()"""
+        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
+        original_shape = x.shape
+        # CUDA kernel expects 3D tensors, reshape if needed
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "Warning: xIELU input tensor expects 3 dimensions"
+                " but got (shape: %s). Reshaping to (shape: %s).\n"
+                "Note: For SGLang this may be expected if sending"
+                "[B*S,D] instead of [B,S,D].",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not torch._dynamo.is_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once(
+                    "torch._dynamo is compiling, using Python version of xIELU."
+                )
+        return self._xielu_python(input)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

@@ -197,6 +327,7 @@ _ACTIVATION_REGISTRY = {
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "gelu_new": NewGELU(),
     "relu2": ReLU2(),
+    "xielu": XIELU(),
 }


@@ -242,7 +373,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()


-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
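For reference, the new `XIELU` op's Python fallback (`_xielu_python` above) evaluates a simple piecewise form. The sketch below restates it as a standalone function; `xielu_reference` and the sample values are illustrative only and are not part of the sglang API (the packaged class is the one registered under `"xielu"` in `_ACTIVATION_REGISTRY`).

```python
import torch
import torch.nn.functional as F


def xielu_reference(
    x: torch.Tensor,
    alpha_p: torch.Tensor,
    alpha_n: torch.Tensor,
    beta: float = 0.5,
    eps: float = -1e-6,
) -> torch.Tensor:
    """Illustrative restatement of XIELU._xielu_python from the diff above.

    alpha_p / alpha_n are pre-softplus parameters (the class stores the
    softplus-inverse of its alpha_*_init defaults); beta and eps mirror
    the class defaults.
    """
    ap = F.softplus(alpha_p)           # positive-branch quadratic coefficient
    an = beta + F.softplus(alpha_n)    # negative-branch coefficient
    eps_t = torch.tensor(eps, dtype=x.dtype)
    return torch.where(
        x > 0,
        ap * x * x + beta * x,                                    # x > 0
        (torch.expm1(torch.min(x, eps_t)) - x) * an + beta * x,   # x <= 0
    )


# Quick check of the fallback path on CPU with made-up parameter values:
x = torch.linspace(-3, 3, 7)
print(xielu_reference(x, torch.tensor(0.5), torch.tensor(0.5)))
```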
sglang/srt/layers/attention/ascend_backend.py
CHANGED
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_bool_env_var
@@ -33,6 +34,7 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None


 class AscendAttnBackend(AttentionBackend):
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()

         self.forward_metadata.block_tables = (
@@ -96,9 +99,13 @@
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-
-
-
+
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        if forward_batch.is_extend_in_batch:
+            seq_lens_list_cumsum[-1] = (
+                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
+            ) * tp_size
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

         self.graph_mode = False

@@ -368,7 +375,7 @@ class AscendAttnBackend(AttentionBackend):
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )

-        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
         q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
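The Ascend backend change above rounds the final cumulative extend length up to a multiple of the attention tensor-parallel size. A minimal sketch of that rounding with made-up lengths; `padded_cumsum` is a hypothetical helper, not part of the backend:

```python
import numpy as np


def padded_cumsum(extend_seq_lens, tp_size, is_extend_in_batch=True):
    """Mirror the new seq_lens_list_cumsum logic in AscendAttnBackend
    (illustrative values; extend_seq_lens here is just a Python list)."""
    cumsum = np.cumsum(extend_seq_lens)
    if is_extend_in_batch:
        # Round the final cumulative length up to the next multiple of tp_size.
        cumsum[-1] = ((cumsum[-1] - 1) // tp_size + 1) * tp_size
    return cumsum


# e.g. three extend requests of 5, 7, and 9 tokens with attention TP size 8:
print(padded_cumsum([5, 7, 9], tp_size=8))  # [ 5 12 24]: last entry 21 rounded up to 24
```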
sglang/srt/layers/attention/fla/chunk.py
ADDED
@@ -0,0 +1,242 @@
+# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import warnings
+from typing import Optional
+
+import torch
+from einops import rearrange
+
+from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
+from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
+from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
+    chunk_scaled_dot_kkt_fwd,
+)
+from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
+from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
+from sglang.srt.layers.attention.fla.solve_tril import solve_tril
+from sglang.srt.layers.attention.fla.utils import (
+    SUPPRESS_LEVEL,
+    autocast_custom_fwd,
+    input_guard,
+)
+from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd
+
+
+def chunk_gated_delta_rule_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+):
+    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+    # obtain WY representation. u is actually the new v.
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
+    )
+    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    w, u = recompute_w_u_fwd(
+        k=k,
+        v=v,
+        beta=beta,
+        A=A,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+    )
+    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
+        k=k,
+        w=w,
+        u=u,
+        g=g,
+        initial_state=initial_state,
+        output_final_state=output_final_state,
+        cu_seqlens=cu_seqlens,
+    )
+    o = chunk_fwd_o(
+        q=q,
+        k=k,
+        v=v_new,
+        h=h,
+        g=g,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+    )
+    if SUPPRESS_LEVEL < 3:
+        return g, o, A, final_state, None, None, None
+    elif SUPPRESS_LEVEL >= 3:
+        return g, o, A, final_state, w, h, v_new
+
+
+class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        use_qk_l2norm_in_kernel: bool = False,
+    ):
+        q_orig = q
+        k_orig = k
+
+        if use_qk_l2norm_in_kernel:
+            q = l2norm_fwd(q)
+            k = l2norm_fwd(k)
+
+        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+        )
+        return o.to(q.dtype), final_state
+
+
+@torch.compiler.disable
+def chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+):
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        scale (Optional[int]):
+            Scale factor for the RetNet attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> o, ht = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert (
+        q.dtype != torch.float32
+    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+    assert (
+        len(beta.shape) == 3
+    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead."
+        )
+        q, k, v, beta, g = map(
+            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
+        )
+    # if not head_first and q.shape[1] < q.shape[2]:
+    #     warnings.warn(
+    #         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+    #         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+    #         "when head_first=False was specified. "
+    #         "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+    #     )
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel,
+    )
+    if head_first:
+        o = rearrange(o, "b t h ... -> b h t ...")
+    return o, final_state
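The new `chunk_gated_delta_rule` entry point validates `cu_seqlens` against the packed batch and any `initial_state`. The sketch below shows one way to build `cu_seqlens` from per-sequence lengths; `build_cu_seqlens` is a hypothetical helper for illustration, not part of the fla module.

```python
import torch
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative boundaries for packed variable-length inputs: sequence i
    occupies tokens [cu_seqlens[i], cu_seqlens[i+1]) of the flattened batch."""
    return torch.tensor([0, *accumulate(seq_lens)], dtype=torch.long)


lens = [128, 64, 256]
cu_seqlens = build_cu_seqlens(lens)  # tensor([  0, 128, 192, 448])
# chunk_gated_delta_rule expects the packed q/k/v to have batch size 1 and, if an
# initial_state is given, initial_state.shape[0] == len(cu_seqlens) - 1 (= 3 here).
```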