sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py ADDED
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class VisionAttention(nn.Module):
+    """Multi-headed attention without any cache, mostly used for ViT."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        # self.tp_size = get_tensor_model_parallel_world_size()
+        # num_heads = self.num_heads_per_partition
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.head_dim = embed_dim // num_heads
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_dim,
+                total_num_heads=num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Input shape: [b, s, embed_dim]
+        Output shape: [s, b, num_heads * head_size]
+        """
+
+        bsz, s, _ = x.shape
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+            q, k, v = qkv.chunk(3, dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+            q, k, v = [
+                x.reshape(
+                    bsz * s, self.num_attention_heads_per_partition, -1
+                ).contiguous()
+                for x in (q, k, v)
+            ]
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            qkv, _ = self.qkv_proj(x)
+            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            new_x_shape = qkv.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        if self.use_qkv_parallel:
+            pass
+        else:
+            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        # [b * s, num_heads, head_size]
+        output = torch.empty_like(q)
+
+        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+        max_seqlen = seq_lens.max().item()
+
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens,
+            max_seqlen,
+            is_causal=False,
+        )
+
+        if self.use_qkv_parallel:
+
+            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, head, head_dim] --> [b, s, head, head_dim]
+            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+            # [s, b, num_heads * head_size]
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+
+            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            output, _ = self.proj(context_layer)
+
+        output = output.view(bsz, s, -1)
+
+        return output
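The rotary helpers in the new vision.py follow the standard rotary-embedding formulation: rotate the first half (or interleaved pairs) of the head dimension and mix it with cos/sin of the position frequencies. The standalone sketch below, which is illustrative only and not part of the wheel, copies the two helpers from the hunk above and exercises them on dummy tensors with made-up sizes to show the expected shapes.

# Standalone sketch (not part of the wheel): the rotary helpers from the new
# vision.py applied to a dummy query tensor. All sizes here are made up.
import torch
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    # Same formulation as the helper added in vision.py.
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    # x: [batch, seqlen, nheads, headdim]; cos/sin: [seqlen, rotary_dim / 2]
    ro_dim = cos.shape[-1] * 2
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


if __name__ == "__main__":
    b, s, h, d = 2, 16, 4, 64  # dummy batch, seqlen, heads, head_dim
    q = torch.randn(b, s, h, d)
    freqs = torch.outer(torch.arange(s, dtype=torch.float32), torch.ones(d // 2))
    out = apply_rotary_emb_torch(q.float(), freqs.cos(), freqs.sin()).type_as(q)
    print(out.shape)  # torch.Size([2, 16, 4, 64]); shape is preserved by the rotation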
sglang/srt/layers/dp_attention.py ADDED
@@ -0,0 +1,69 @@
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        False,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
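compute_dp_attention_world_info splits the tensor-parallel group into dp_size attention sub-groups: each TP rank is mapped to a data-parallel replica (dp_rank) and a rank within that replica's attention TP group. The sketch below is illustrative only; it copies the function from the hunk above and prints the mapping for a hypothetical tp_size=4, dp_size=2 layout.

# Standalone sketch (not part of the wheel): rank arithmetic of
# compute_dp_attention_world_info for a hypothetical tp_size=4, dp_size=2 setup.
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size      # TP degree inside each attention group
    dp_rank = tp_rank // attn_tp_size      # which data-parallel replica this rank joins
    attn_tp_rank = tp_rank % attn_tp_size  # rank within that replica's attention TP group
    return attn_tp_rank, attn_tp_size, dp_rank


for tp_rank in range(4):
    print(tp_rank, compute_dp_attention_world_info(True, tp_rank, tp_size=4, dp_size=2))
# 0 (0, 2, 0)
# 1 (1, 2, 0)
# 2 (0, 2, 1)
# 3 (1, 2, 1)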
sglang/srt/layers/linear.py CHANGED
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,10 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
-# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)


+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.

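The new in-tree LinearBase replaces the vLLM import and keeps the same dispatch idea: when no quant_config is given, the layer falls back to an unquantized linear method, otherwise the config picks the quantize method. The toy below is a sketch of that dispatch pattern only; ToyLinear and ToyUnquantizedMethod are made-up names and not part of the package.

# Standalone toy (not part of the wheel): the quant-method dispatch pattern used by
# the new LinearBase, with made-up class names and no quantization backend attached.
from typing import Optional

import torch


class ToyUnquantizedMethod:
    def apply(self, layer, x):
        # Plain dense matmul, mirroring the unquantized fallback path.
        return torch.nn.functional.linear(x, layer.weight)


class ToyLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, quant_config=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        # Same branching as LinearBase.__init__: config wins, otherwise unquantized.
        if quant_config is None:
            self.quant_method = ToyUnquantizedMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(self, x)


print(ToyLinear(8, 4)(torch.randn(2, 8)).shape)  # torch.Size([2, 4])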
sglang/srt/layers/logits_processor.py CHANGED
@@ -14,17 +14,18 @@
 """Logits processing."""

 import dataclasses
+import logging
 from typing import List, Optional, Union

 import torch
 import triton
 import triton.language as tl
 from torch import nn
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None

     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
-    ):
+    ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-
-
-
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )

-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-
+                        pruned_states
                         if logits_metadata.capture_hidden_mode.is_last()
                         else None
                     )
                 ),
             )
         else:
-
-
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits

             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -195,25 +207,18 @@
         else:
             input_top_logprobs_val = input_top_logprobs_idx = None

-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device=
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device=
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )

             return LogitsProcessorOutput(
-                next_token_logits=
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
@@ -237,8 +245,6 @@
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)

-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
@@ -246,27 +252,6 @@

         return logits

-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
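The rewritten forward keeps only the tokens whose logprobs were requested (per request, everything after extend_logprob_start_lens) and records, in sample_indices, where each request's last kept token lands in the concatenated pruned sequence, so logits[sample_indices] gives the next-token logits. The plain-Python sketch below is illustrative only and not part of the wheel; the lengths are made up and plain ranges stand in for hidden_states.

# Standalone sketch (not part of the wheel): the index bookkeeping done by the new
# LogitsProcessor.forward when input logprobs are requested. Lengths are made up.
extend_seq_lens_cpu = [5, 3]            # tokens per request in this prefill batch
extend_logprob_start_lens_cpu = [2, 0]  # leading tokens to skip per request

sample_index_pt = -1
sample_indices = []                     # last kept token of each request, indexed
pt, pruned_slices = 0, []               # into the concatenated pruned tokens
for start_len, extend_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
    pruned_slices.append(range(pt + start_len, pt + extend_len))
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)
    pt += extend_len

print([list(s) for s in pruned_slices])  # [[2, 3, 4], [5, 6, 7]]
print(sample_indices)                    # [2, 5]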
sglang/srt/layers/moe/ep_moe/layer.py CHANGED
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.utils import is_hip, set_weight_attrs

 logger = logging.getLogger(__name__)
sglang/srt/layers/moe/fused_moe_native.py CHANGED
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -44,3 +45,71 @@ def fused_moe_forward_native(
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = SiluAndMul()(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
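moe_forward_native groups tokens by expert (count, argsort, gather), runs each expert on its contiguous slice, then scatters the results back to (token, slot) order and takes the routing-weighted sum over the top-k slots. The toy below is illustrative only and not part of the wheel: it uses made-up sizes and identity "experts" so only the routing and scatter-back pattern is exercised.

# Standalone toy (not part of the wheel): the sort-by-expert / scatter-back pattern
# from moe_forward_native, with 4 tokens, 3 experts, top-2, and identity experts.
import torch

num_tokens, hidden, num_experts, top_k = 4, 8, 3, 2
x = torch.randn(num_tokens, hidden)
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])
topk_weights = torch.full((num_tokens, top_k), 0.5)

# Count tokens routed to each expert, then sort the (token, slot) pairs by expert id.
cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)            # e.g. tensor([3, 2, 3])
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // top_k]               # tokens grouped expert by expert

# Run each expert on its contiguous slice (identity here instead of an MLP).
outputs, start = [], 0
for e, n in enumerate(tokens_per_expert.tolist()):
    outputs.append(sorted_tokens[start : start + n])   # expert e's "output"
    start += n
outs = torch.cat(outputs, dim=0)

# Scatter back to (token, slot) order and combine with the routing weights.
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final = (new_x.view(num_tokens, top_k, hidden) * topk_weights.unsqueeze(-1)).sum(dim=1)
print(final.shape)  # torch.Size([4, 8])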
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py CHANGED
@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)

-
-
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

-    is_hip_flag = False
-else:
-    is_hip_flag = True

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0