sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
@@ -35,10 +34,8 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
         self.forward_metadata = None
 
@@ -50,23 +47,23 @@ class TritonAttnBackend(AttentionBackend):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
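The decode branch no longer allocates a per-token logits buffer; instead `attn_logits` is sized `(batch_size, num_head, num_kv_splits, v_head_dim + 1)`, which suggests a split-KV ("flash-decoding" style) decode: each of `num_kv_splits` chunks of a request's KV cache contributes a partial output of length `v_head_dim`, with the extra slot presumably holding that chunk's log-sum-exp. A minimal plain-PyTorch sketch of what one (request, head) slice of such a buffer could contain, assuming that layout (an illustration only, not the package's Triton kernel):

```python
import torch

def decode_splits(q, k, v, num_kv_splits, sm_scale):
    """Toy per-split partials for one head of one request (hypothetical layout)."""
    out = torch.empty(num_kv_splits, v.shape[-1] + 1)
    for i, (ks, vs) in enumerate(zip(k.chunk(num_kv_splits), v.chunk(num_kv_splits))):
        scores = ks @ q * sm_scale                        # (chunk_len,)
        out[i, :-1] = torch.softmax(scores, dim=0) @ vs   # split-local attention output
        out[i, -1] = torch.logsumexp(scores, dim=0)       # assumed content of the "+ 1" slot
    return out

partials = decode_splits(
    torch.randn(64), torch.randn(128, 64), torch.randn(128, 64),
    num_kv_splits=8, sm_scale=64 ** -0.5,
)
print(partials.shape)  # torch.Size([8, 65])
```

Because the buffer shape now depends only on the batch size and fixed model dimensions rather than the total token count, it can also be preallocated once for CUDA-graph capture, which is what the next hunk does.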
@@ -75,11 +72,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )
 
@@ -92,9 +86,7 @@ class TritonAttnBackend(AttentionBackend):
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
@@ -114,7 +106,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,11 +120,12 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
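`forward_extend` (and `forward_decode` below) now takes a `save_kv_cache` flag, defaulting to `True`: the write of the current tokens' `k`/`v` into the pool at `out_cache_loc` becomes conditional, so a caller whose KV entries are already in the pool can skip the write. A toy sketch of the pattern, assuming that reading; `ToyPool` and `decode_step` are stand-ins, not sglang APIs:

```python
import torch

class ToyPool:
    """Stand-in for the KV pool; not an sglang class."""
    def __init__(self, slots, heads, head_dim):
        self.k = torch.zeros(slots, heads, head_dim)
        self.v = torch.zeros(slots, heads, head_dim)

    def set_kv_buffer(self, loc, k, v):
        self.k[loc] = k
        self.v[loc] = v

def decode_step(q, k, v, pool, out_cache_loc, seq_loc, save_kv_cache=True):
    if save_kv_cache:  # the new opt-out: skip when these slots were written earlier
        pool.set_kv_buffer(out_cache_loc, k, v)
    keys, vals = pool.k[seq_loc], pool.v[seq_loc]          # (seq, heads, head_dim)
    scores = torch.einsum("hd,shd->hs", q, keys) * keys.shape[-1] ** -0.5
    return torch.einsum("hs,shd->hd", torch.softmax(scores, dim=-1), vals)

pool = ToyPool(slots=16, heads=2, head_dim=4)
out = decode_step(
    torch.randn(2, 4), torch.randn(1, 2, 4), torch.randn(1, 2, 4),
    pool, out_cache_loc=torch.tensor([3]), seq_loc=torch.tensor([3]),
)
```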
@@ -146,7 +145,13 @@ class TritonAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -158,11 +163,12 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -171,10 +177,9 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
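The decode call now passes `self.num_kv_splits` where `max_seq_len` used to go, matching the per-split buffer allocated in `init_forward_metadata`. The reduction over splits presumably happens inside `decode_attention_fwd`; assuming each split's partial output is already normalized within its own chunk and the `+ 1` slot stores that chunk's log-sum-exp, the merge it would have to perform is roughly the following sketch (not the package's kernel):

```python
import torch

def combine_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    """Merge per-split partials shaped (bs, heads, num_kv_splits, v_head_dim + 1)."""
    partial_out = attn_logits[..., :-1]      # (bs, heads, splits, v_head_dim)
    lse = attn_logits[..., -1]               # (bs, heads, splits)
    weights = torch.softmax(lse, dim=-1)     # each split's share of the softmax mass
    return (weights.unsqueeze(-1) * partial_out).sum(dim=2)  # (bs, heads, v_head_dim)

out = combine_kv_splits(torch.randn(2, 8, 8, 65))
print(out.shape)  # torch.Size([2, 8, 64])
```

A larger split count exposes more parallelism over long sequences at the cost of this extra reduction; the value is read from `model_runner.server_args.triton_attention_num_kv_splits`, so it is exposed as a server-level tunable rather than hard-coded in the backend.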