sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
ADDED
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class VisionAttention(nn.Module):
+    """Multi-headed attention without any cache, mostly used for ViT."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        # self.tp_size = get_tensor_model_parallel_world_size()
+        # num_heads = self.num_heads_per_partition
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.head_dim = embed_dim // num_heads
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_dim,
+                total_num_heads=num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Input shape: [b, s, embed_dim]
+        Output shape: [s, b, num_heads * head_size]
+        """
+
+        bsz, s, _ = x.shape
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+            q, k, v = qkv.chunk(3, dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+            q, k, v = [
+                x.reshape(
+                    bsz * s, self.num_attention_heads_per_partition, -1
+                ).contiguous()
+                for x in (q, k, v)
+            ]
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            qkv, _ = self.qkv_proj(x)
+            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            new_x_shape = qkv.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        if self.use_qkv_parallel:
+            pass
+        else:
+            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        # [b * s, num_heads, head_size]
+        output = torch.empty_like(q)
+
+        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+        max_seqlen = seq_lens.max().item()
+
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens,
+            max_seqlen,
+            is_causal=False,
+        )
+
+        if self.use_qkv_parallel:
+
+            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, head, head_dim] --> [b, s, head, head_dim]
+            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+            # [s, b, num_heads * head_size]
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+
+            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            output, _ = self.proj(context_layer)
+
+        output = output.view(bsz, s, -1)
+
+        return output
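A minimal sketch (not part of the package diff) of the non-interleaved rotary path added in vision.py above. It mirrors rotate_half and apply_rotary_emb_torch using only torch and einops rather than importing sglang; the tensor sizes are illustrative assumptions.

import torch
from einops import repeat


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved variant: split the head dim in half and rotate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_torch(x, cos, sin):
    # cos/sin have shape (seqlen, rotary_dim / 2); broadcast over batch and heads.
    ro_dim = cos.shape[-1] * 2
    cos = repeat(cos, "... d -> ... 1 (2 d)")
    sin = repeat(sin, "... d -> ... 1 (2 d)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]],
        dim=-1,
    )


b, s, h, d = 2, 4, 3, 8                 # batch, seqlen, heads, head_dim (assumed)
q = torch.randn(b, s, h, d)
freqs = torch.randn(s, d // 2)          # rotary_dim equals head_dim here
q_rot = apply_rotary_emb_torch(q.float(), freqs.cos(), freqs.sin()).type_as(q)
assert q_rot.shape == q.shape           # rotation preserves the [b, s, h, d] layout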
sglang/srt/layers/dp_attention.py
ADDED
@@ -0,0 +1,71 @@
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        SYNC_TOKEN_IDS_ACROSS_TP,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
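A minimal sketch (not part of the package diff) showing what the dp_attention helpers above compute: with dp_size data-parallel replicas carved out of a tp_size tensor-parallel group, each rank gets an attention-TP rank, an attention-TP size, and a DP rank. The function body is copied from compute_dp_attention_world_info for illustration.

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, dp_rank


# tp_size=8, dp_size=2: ranks 0-3 form DP replica 0, ranks 4-7 form DP replica 1,
# and each replica runs attention with a 4-way TP group.
for tp_rank in range(8):
    print(tp_rank, compute_dp_attention_world_info(True, tp_rank, 8, 2))
# 0 (0, 4, 0) ... 3 (3, 4, 0), 4 (0, 4, 1) ... 7 (3, 4, 1)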
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda_available
 
-if
-    from
+if is_cuda_available():
+    from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
@@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not
+if not is_cuda_available():
     logger.info(
-        "
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
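A minimal sketch (not part of the package diff) of the import-guard pattern the layernorm.py change above relies on: prefer the sgl_kernel fused ops when running on CUDA, otherwise log and fall back to the vLLM implementations. The is_cuda_available stand-in is an assumption for illustration; the real check lives in sglang.srt.utils.

import logging

import torch

logger = logging.getLogger(__name__)


def is_cuda_available() -> bool:
    # Simplified stand-in for sglang.srt.utils.is_cuda_available.
    return torch.cuda.is_available()


if is_cuda_available():
    from sgl_kernel import fused_add_rmsnorm  # noqa: F401
else:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import RMSNorm  # noqa: F401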
sglang/srt/layers/linear.py
CHANGED
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,10 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
-# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)
 
 
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
 
@@ -293,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
 
         self.gather_output = gather_output
+        self.use_presharded_weights = use_presharded_weights
 
         # Divide the weight matrix along the last dimension.
         if tp_rank is None:
@@ -366,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -382,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(
+        param.load_column_parallel_weight(
+            loaded_weight,
+            tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -463,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
+        self.prefix = prefix
 
     def weight_loader(
         self,
@@ -707,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        load_presharded_attn: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -736,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
+        self.use_presharded_weights = load_presharded_attn
 
         super().__init__(
             input_size=input_size,
@@ -748,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -806,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size=shard_size, shard_offset=shard_offset
            )
 
-
-
-
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    param.output_dim, shard_offset, shard_size
+                )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
     def weight_loader_v2(
@@ -846,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_offset=shard_offset,
                 shard_size=shard_size,
                 tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
             )
 
     def weight_loader(
@@ -951,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 param, orig_qkv_offsets, shard_id
             )
 
-
-
-
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size
+                )
             self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
@@ -1013,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
 
         # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
-        if not use_bitsandbytes_4bit:
+        if not use_bitsandbytes_4bit and not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
        # Special case for for AQLM codebooks.
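A minimal sketch (not part of the package diff) of what the use_presharded_weights flag added to linear.py above changes during weight loading: when the checkpoint already stores this rank's shard, the loader skips the narrow() that would otherwise slice the rank's rows out of the full tensor. The helper name is hypothetical.

import torch


def load_column_shard(loaded_weight, tp_rank, tp_size, use_presharded_weights):
    output_dim = 0
    if not use_presharded_weights:
        # Unsharded checkpoint: cut out this rank's slice along the output dim.
        shard_size = loaded_weight.shape[output_dim] // tp_size
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
    # Pre-sharded checkpoint: the tensor is already this rank's shard.
    return loaded_weight


full = torch.arange(8.0).reshape(8, 1)          # full checkpoint weight
print(load_column_shard(full, tp_rank=1, tp_size=4, use_presharded_weights=False))
# rows 2 and 3 only

pre = full[2:4]                                  # checkpoint saved per-rank
print(load_column_shard(pre, tp_rank=1, tp_size=4, use_presharded_weights=True))
# returned unchanged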
sglang/srt/layers/logits_processor.py
CHANGED
@@ -14,17 +14,18 @@
 """Logits processing."""
 
 import dataclasses
+import logging
 from typing import List, Optional, Union
 
 import torch
 import triton
 import triton.language as tl
 from torch import nn
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
-    ):
+    ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
 
         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-
-
-
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )
 
-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-
+                        pruned_states
                         if logits_metadata.capture_hidden_mode.is_last()
                        else None
                    )
                ),
            )
        else:
-
-
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits
 
            # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                input_logprobs, logits_metadata
            )
@@ -195,25 +207,18 @@ class LogitsProcessor(nn.Module):
         else:
             input_top_logprobs_val = input_top_logprobs_idx = None
 
-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device=
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device=
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )
 
             return LogitsProcessorOutput(
-                next_token_logits=
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
@@ -237,8 +245,6 @@
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)
 
-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
@@ -246,27 +252,6 @@
 
         return logits
 
-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
@@ -311,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
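A minimal sketch (not part of the package diff) of the sample-index bookkeeping that the rewritten LogitsProcessor.forward above performs in the prefill-with-logprobs branch: for each request it keeps hidden states from start_len to extend_len and records the flat index of the last kept token, so next-token logits can later be gathered as logits[sample_indices]. The helper name is hypothetical.

def build_sample_indices(start_lens, extend_lens):
    sample_index_pt = -1
    sample_indices = []
    for start_len, extend_len in zip(start_lens, extend_lens):
        # (extend_len - start_len) tokens are kept for this request; the last
        # one is the position whose logits feed next-token sampling.
        sample_index_pt += extend_len - start_len
        sample_indices.append(sample_index_pt)
    return sample_indices


# Two requests with 5 and 3 extend tokens, logprobs requested from positions 2 and 0:
# request 0 keeps 3 tokens (flat indices 0-2), request 1 keeps 3 tokens (indices 3-5).
print(build_sample_indices([2, 0], [5, 3]))  # [2, 5]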