sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py
ADDED
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TritonAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.decode_attention import (
+            decode_attention_fwd,
+        )
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        super().__init__()
+
+        self.decode_attention_fwd = decode_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
+            max_extend_len = None
+        else:
+            start_loc = attn_logits = max_seq_len = None
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
+        )
+
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
+        )
+
+        self.decode_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            start_loc,
+            forward_batch.seq_lens,
+            attn_logits,
+            max_seq_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py
RENAMED
@@ -22,7 +22,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)

 CUDA_CAPABILITY = torch.cuda.get_device_capability()

sglang/srt/layers/fused_moe/patch.py
ADDED
@@ -0,0 +1,117 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def select_experts_native(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+):
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+    else:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    return topk_weights, topk_ids
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts_native(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+    )
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
sglang/srt/layers/layernorm.py
CHANGED
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from sglang.srt.utils import
+from sglang.srt.utils import is_flashinfer_available

-if
+if is_flashinfer_available():
     from flashinfer.norm import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
        return out


-if
+if not is_flashinfer_available():
    logger.info(
-        "FlashInfer is not available on
+        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
sglang/srt/layers/logits_processor.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 @dataclasses.dataclass
@@ -61,26 +61,30 @@ class LogitsMetadata:
    extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

    @classmethod
-    def
-
-
+    def from_forward_batch(cls, forward_batch: ForwardBatch):
+        if forward_batch.return_logprob:
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        else:
+            return_top_logprob = False
+
+        if forward_batch.forward_mode.is_extend():
            extend_logprob_pruned_lens_cpu = [
                extend_len - start_len
                for extend_len, start_len in zip(
-
-
+                    forward_batch.extend_seq_lens,
+                    forward_batch.extend_logprob_start_lens_cpu,
                )
            ]
        else:
            extend_logprob_pruned_lens_cpu = None
        return cls(
-            forward_mode=
-            top_logprobs_nums=
-            return_logprob=
+            forward_mode=forward_batch.forward_mode,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
+            return_logprob=forward_batch.return_logprob,
            return_top_logprob=return_top_logprob,
-            extend_seq_lens=
-            extend_seq_lens_cpu=
-            extend_logprob_start_lens_cpu=
+            extend_seq_lens=forward_batch.extend_seq_lens,
+            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
            extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
        )

@@ -162,10 +166,10 @@ class LogitsProcessor(nn.Module):
        input_ids,
        hidden_states,
        weight,
-        logits_metadata: Union[LogitsMetadata,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
    ):
-        if isinstance(logits_metadata,
-            logits_metadata = LogitsMetadata.
+        if isinstance(logits_metadata, ForwardBatch):
+            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
        assert isinstance(logits_metadata, LogitsMetadata)

        # Get the last hidden states and last logits for the next token prediction
sglang/srt/layers/pooler.py
CHANGED
@@ -7,7 +7,7 @@ from enum import IntEnum
 import torch
 import torch.nn as nn

-from sglang.srt.model_executor.model_runner import
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
        self.normalize = normalize

    def forward(
-        self, hidden_states: torch.Tensor,
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> EmbeddingPoolerOutput:
        if self.pooling_type == PoolingType.LAST:
-            last_token_indices = torch.cumsum(
+            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_indices]
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -39,7 +38,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "gptq_marlin": GPTQMarlinConfig,
    "awq_marlin": AWQMarlinConfig,
    "gptq": GPTQConfig,
-    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
sglang/srt/layers/radix_attention.py
CHANGED
@@ -17,7 +17,7 @@ limitations under the License.

 from torch import nn

-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class RadixAttention(nn.Module):
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
-        sliding_window_size: int = -1,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
@@ -47,12 +48,13 @@ class RadixAttention(nn.Module):
        self.layer_id = layer_id
        self.logit_cap = logit_cap
        self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention

-    def forward(self, q, k, v,
+    def forward(self, q, k, v, forward_batch: ForwardBatch):
        if k is not None:
            # For cross-layer sharing, kv can be None
            assert v is not None
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return
+        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
sglang/srt/layers/sampler.py
CHANGED
@@ -7,10 +7,9 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import
+from sglang.srt.utils import is_flashinfer_available

-
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.sampling import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
@@ -43,7 +42,10 @@ class Sampler(nn.Module):
            torch.isnan(probs), torch.full_like(probs, 1e-10), probs
        )

-        if
+        if sampling_info.top_ks.max().item() <= 1:
+            # Use torch.argmax if all requests use greedy sampling
+            batch_next_token_ids = torch.argmax(probs, -1)
+        elif global_server_args_dict["sampling_backend"] == "flashinfer":
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
    """
    # Lazy import to suppress some warnings
    from torchao.quantization import (
+        float8_dynamic_activation_float8_weight,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
+    from torchao.quantization.observer import PerRow, PerTensor

    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
        # this requires newer hardware
        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
        quantize_(dummy_linear, float8_weight_only())
+    elif "fp8dq" in torchao_config:
+        granularity = torchao_config.split("-")[-1]
+        GRANULARITY_MAP = {
+            "per_row": PerRow(),
+            "per_tensor": PerTensor(),
+        }
+        assert (
+            granularity in GRANULARITY_MAP
+        ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
+        quantize_(
+            dummy_linear,
+            float8_dynamic_activation_float8_weight(
+                granularity=GRANULARITY_MAP[granularity]
+            ),
+        )
+
    return dummy_linear.weight

sglang/srt/lora/lora.py
CHANGED
@@ -28,19 +28,19 @@ from typing import Any, Dict, List, Optional, Tuple
 import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader

-from sglang.srt.
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 class BaseLayerWithLoRA(nn.Module):
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs,
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
-        self.
+        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        # FIXME
-        assert lora_a_output.shape[-1] == self.lora_rank * 2
        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            weights=self.B_buffer[:, left:right, :].contiguous(),
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs,
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q = B_buffer_q
        self.B_buffer_kv = B_buffer_kv
        self.bs = bs
-        self.
+        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            weights=self.A_buffer_qkv,
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        # FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            weights=self.B_buffer_q,
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        # kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                batch_size=self.bs,
                weight_column_major=True,
-
+                seg_indptr=self.seg_indptr,
                weight_indices=self.weight_indices,
            )
        )
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs,
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
-        self.
+        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
-
+            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling