sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they were published to their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
@@ -18,6 +19,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
 
 
 def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
@@ -121,20 +123,20 @@ class VisionAttention(nn.Module):
                 head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
-                prefix=
+                prefix=add_prefix("qkv_proj", prefix),
             )
         else:
             self.qkv_proj = ColumnParallelLinear(
                 input_size=embed_dim,
                 output_size=3 * projection_size,
                 quant_config=quant_config,
-                prefix=
+                prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
             input_size=embed_dim,
             output_size=embed_dim,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("out_proj", prefix),
         )
 
     def forward(
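The `add_prefix("qkv_proj", prefix)` calls above build dotted module prefixes for the quantization config instead of hand-formatted strings. The helper below is a hypothetical stand-in, not the actual `sglang.srt.utils.add_prefix`, but it sketches the usual shape of such a prefix joiner:

```python
def join_prefix(name: str, prefix: str) -> str:
    """Hypothetical sketch: join a submodule name onto a dotted parent prefix.

    An empty parent prefix yields just the name, so quantization configs can
    match fully qualified weight names such as "visual.blocks.0.attn.qkv_proj".
    """
    return name if not prefix else f"{prefix}.{name}"


# join_prefix("qkv_proj", "visual.blocks.0.attn") -> "visual.blocks.0.attn.qkv_proj"
# join_prefix("qkv_proj", "")                     -> "qkv_proj"
```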
@@ -223,9 +225,6 @@ class VisionSdpaAttention(nn.Module):
 
     """
 
-    # TODO: Should it be released after used?
-    _mask_cache = {}
-
     def __init__(
         self,
         head_size: int,
@@ -239,75 +238,61 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout
 
-
-
-
-
-
-
-
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
-
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
         Args:
-
-
-
-
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
         Returns:
-
+            attention mask tensor
         """
-
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-
         if flatten_batch:
-            mask = torch.zeros([1, s, s],
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
 
             mask = (row_indices < seq_lens) & (col_indices < seq_lens)
 
-
-
+        return mask
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
 
-
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
 
-        return
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
 
     def forward(
         self,
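The refactor above replaces the hand-rolled class-level `_mask_cache` dict with `functools.lru_cache` on a static method; because `lru_cache` requires hashable arguments, the `cu_seqlens` tensor is converted to a tuple of ints before the call. A minimal standalone sketch of the same pattern (hypothetical names, not the package code):

```python
from functools import lru_cache

import torch


class MaskBuilder:
    @staticmethod
    @lru_cache(maxsize=128)
    def _build_block_mask(s: int, cu_seqlens: tuple) -> torch.Tensor:
        # Boolean block-diagonal mask: tokens attend only within their own sequence.
        mask = torch.zeros(1, s, s, dtype=torch.bool)
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            mask[..., start:end, start:end] = True
        return mask

    def get_mask(self, s: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Tensors are not hashable, so the cache key is a tuple of Python ints.
        return self._build_block_mask(s, tuple(cu_seqlens.tolist()))


builder = MaskBuilder()
cu = torch.tensor([0, 3, 5])
m1 = builder.get_mask(5, cu)
m2 = builder.get_mask(5, cu)  # second call is served from the lru_cache
assert m1 is m2
```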
@@ -330,15 +315,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s,
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )
+
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
-
+
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
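In the full-precision branch above, the boolean patch mask is converted to an additive mask before being added to the attention scores: allowed positions contribute 0 and masked positions contribute the most negative representable value, which softmax then drives to near-zero probability. A small standalone illustration, not the package code:

```python
import torch

scores = torch.randn(1, 1, 4, 4)                      # [b, h, s, s] attention logits
allowed = torch.tensor([[True, True, False, False]])  # only the first 2 tokens are valid
bool_mask = allowed[:, None, None, :] & allowed[:, None, :, None]  # [b, 1, s, s]

# 0.0 where attention is allowed, a huge negative value where it is masked.
additive = (~bool_mask) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive, dim=-1)

# Masked columns receive ~0 probability for the valid query rows.
print(probs[0, 0, 0])
```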
@@ -354,7 +347,12 @@ class VisionSdpaAttention(nn.Module):
         # SDPA
         # [b, h, s, head_size]
         output = F.scaled_dot_product_attention(
-            q,
+            q,
+            k,
+            v,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout,
+            is_causal=False,
         )
 
         # [b, h, s, head_size] --> [b * s, h, head_size]
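The expanded call above passes the boolean mask and dropout directly to PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which accepts a boolean `attn_mask` (True means the position may be attended to) and handles the masking numerics internally. A minimal usage sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

b, h, s, d = 2, 8, 16, 64
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)

# Boolean mask of shape [b, 1, s, s]; True = position participates in attention.
attn_mask = torch.ones(b, 1, s, s, dtype=torch.bool)

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,      # the module above passes self.dropout instead
    is_causal=False,
)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```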
@@ -380,7 +378,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
-        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
sglang/srt/layers/dp_attention.py
CHANGED
@@ -1,6 +1,21 @@
+from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING, Union
+
 import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    tensor_model_parallel_all_reduce,
+)
 
-
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
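The `if TYPE_CHECKING:` import above is the standard way to use `ForwardBatch` purely as a type annotation without importing it at runtime, which avoids a circular import; combined with `from __future__ import annotations`, the annotation is never evaluated at runtime. A generic sketch of the pattern with hypothetical module names:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime,
    # so it cannot introduce a circular import.
    from my_package.batch import ForwardBatch  # hypothetical module


def count_tokens(forward_batch: ForwardBatch) -> int:
    # Thanks to `from __future__ import annotations`, the annotation above
    # is stored as a plain string and needs no runtime import.
    return forward_batch.input_ids.shape[0]
```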
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
 def get_attention_dp_size():
     assert _DP_SIZE is not None, "dp attention not initialized!"
     return _DP_SIZE
+
+
+def get_dp_local_info(forward_batch: ForwardBatch):
+    dp_rank = get_attention_dp_rank()
+
+    if forward_batch.dp_local_start_pos is None:
+        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
+        if dp_rank == 0:
+            local_start_pos = torch.zeros_like(cumtokens[0])
+        else:
+            local_start_pos = cumtokens[dp_rank - 1]
+        local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
+
+        forward_batch.dp_local_start_pos = local_start_pos
+        forward_batch.dp_local_num_tokens = local_num_tokens
+
+    return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
+
+
+@triton.jit
+def memcpy_triton_kernel(
+    dst_ptr,
+    src_ptr,
+    offset_ptr,
+    sz_ptr,
+    offset_src,
+    chunk_size,  # multiplied for offset and sz
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0).to(tl.int64)
+    offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
+    sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
+
+    start_index = pid * BLOCK_SIZE
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = start_index + offs < sz
+
+    if offset_src:
+        data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
+        tl.store(dst_ptr + start_index + offs, data, mask=mask)
+    else:
+        data = tl.load(src_ptr + start_index + offs, mask=mask)
+        tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
+
+
+def prod(x):
+    return functools.reduce(lambda a, b: a * b, x, 1)
+
+
+def memcpy_triton(dst, src, dim, offset, sz, offset_src):
+    max_size = min(src.numel(), dst.numel())
+    assert dim == 0, "dim != 0 unsupported"
+    assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
+    chunk_size = prod(src.shape[1:])
+    BLOCK_SIZE = 8192
+    grid = (triton.cdiv(max_size, BLOCK_SIZE),)
+
+    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
+
+
+def dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    layer_id: Union[str, int],
+):
+    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
+    global_tokens.fill_(0)
+    assert local_tokens.is_contiguous()
+    assert global_tokens.is_contiguous()
+    if local_tokens.shape[0] > 0 and (
+        layer_id != "embedding" or get_attention_tp_rank() == 0
+    ):
+        assert (
+            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+        ), "aliasing between global_tokens and local_tokens not allowed"
+        memcpy_triton(
+            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
+        )
+
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
+    NUM_GPUS_PER_NODE = 8
+    if (
+        not local_tokens.dtype.is_floating_point
+        and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
+    ):
+        torch.ops.sglang.inplace_all_reduce(
+            global_tokens, group_name=get_tp_group().unique_name
+        )
+    else:
+        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+
+
+def dp_scatter(
+    local_tokens: torch.Tensor,  # output
+    global_tokens: torch.Tensor,  # input
+    forward_batch: ForwardBatch,
+):
+    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
+    # since local_tokens may be padded for cuda graph
+    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+    local_tokens.fill_(0)
+    assert local_tokens.is_contiguous()
+    assert global_tokens.is_contiguous()
+    if local_tokens.shape[0] > 0:
+        assert (
+            local_tokens.untyped_storage().data_ptr()
+            != global_tokens.untyped_storage().data_ptr()
+        ), "aliasing between local_tokens and global_tokens not allowed"
+        memcpy_triton(
+            local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
+        )
+
+
+def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
+    def do_logits_dp_scatter(logits: torch.Tensor):
+        local_logits = torch.empty(
+            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
+            dtype=logits.dtype,
+            device=logits.device,
+        )
+        dp_scatter(local_logits, logits, forward_batch)
+        return local_logits
+
+    return do_logits_dp_scatter
sglang/srt/layers/layernorm.py
CHANGED
sglang/srt/layers/linear.py
CHANGED
@@ -38,6 +38,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "AWQLinearMethod",
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
+    "BlockInt8LinearMethod",
     "MarlinLinearMethod",
     "QQQLinearMethod",
     "GPTQMarlin24LinearMethod",
@@ -425,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
         from sglang.srt.layers.parameter import _ColumnvLLMParameter
 
         if isinstance(param, _ColumnvLLMParameter):
-            # FIXME: why would we need this special case?
             param.load_column_parallel_weight(
                 loaded_weight,
                 tp_rank=self.tp_rank,
                 use_presharded_weights=self.use_presharded_weights,
            )
         else:
+            # FIXME: This branch is needed to load deepseek v3 awq.
+            # However, we should fix this and avoid the branching here.
             param.load_column_parallel_weight(loaded_weight)
 
     def forward(self, input_):