sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +113 -3
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +72 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +114 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +153 -94
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +2 -1
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +0 -14
- sglang/srt/server.py +58 -16
- sglang/srt/server_args.py +42 -22
- sglang/srt/utils.py +87 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +18 -4
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,114 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""Constrained decoding with xgrammar backend."""
|
17
|
+
|
18
|
+
from typing import List, Optional, Tuple
|
19
|
+
|
20
|
+
import torch
|
21
|
+
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
22
|
+
|
23
|
+
from sglang.srt.constrained.base_grammar_backend import (
|
24
|
+
BaseGrammarBackend,
|
25
|
+
BaseGrammarObject,
|
26
|
+
)
|
27
|
+
|
28
|
+
# Upper bound on how many accepted tokens a GrammarMatcher can roll back.
# Passed as ``max_rollback_tokens`` whenever a GrammarMatcher is created in
# this module; jump_and_retokenize relies on it when it calls
# ``matcher.rollback``.
MAX_ROLLBACK_TOKENS = 10
|
29
|
+
|
30
|
+
|
31
|
+
class XGrammarGrammar(BaseGrammarObject):
    """Per-request grammar state backed by an xgrammar ``GrammarMatcher``.

    Attributes:
        matcher: The ``GrammarMatcher`` whose state is advanced by accepted
            tokens and queried for the allowed-token bitmask.
        vocab_size: Vocabulary size used as the matcher's ``mask_vocab_size``
            and when converting the bitmask into rejected token ids.
        ctx: The ``CompiledGrammar`` the matcher was built from; ``copy``
            builds a new matcher from it.
    """

    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx

    def _accept_or_raise(self, token: int) -> None:
        """Feed ``token`` to the matcher and raise if the grammar rejects it.

        The matcher call is made unconditionally rather than inside an
        ``assert`` statement: ``assert`` is stripped under ``python -O``,
        which would silently skip advancing the grammar state.
        """
        if not self.matcher.accept_token(token):
            raise AssertionError(
                f"Token {token} was rejected by the xgrammar GrammarMatcher"
            )

    def accept_token(self, token: int):
        """Advance the grammar state by one token.

        Raises:
            AssertionError: if the matcher rejects ``token``.
        """
        self._accept_or_raise(token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ``([], s)`` when the matcher has a non-empty jump-forward string.

        ``tokenizer`` is not used by this backend. The empty list presumably
        means "no forced token ids, only a forced string" -- TODO confirm
        against the caller. Returns ``None`` when there is nothing to jump.
        """
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Unpack the helper from ``try_jump_forward`` into ``(string, -1)``.

        NOTE(review): ``-1`` looks like a "no state" placeholder for this
        backend; confirm how callers interpret it.
        """
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Resynchronize the matcher after output ids were retokenized.

        Rolls back the matcher past the point where ``old_output_ids`` and
        ``new_output_ids`` first differ, then accepts the remaining new ids.
        ``next_state`` is not used by this backend.

        Raises:
            AssertionError: if the matcher rejects one of the new ids.
        """
        # Length of the common prefix. zip() stops at the shorter list, so a
        # retokenization that yields fewer ids than before cannot IndexError.
        k = 0
        for old_id, new_id in zip(old_output_ids, new_output_ids):
            if old_id != new_id:
                break
            k += 1

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for token in new_output_ids[k:]:
            self._accept_or_raise(token)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
        """Set ``vocab_mask`` to 1 at every token id the grammar disallows next.

        Assumes ``vocab_mask`` is a 1-D tensor indexable by token id with at
        least ``vocab_size`` entries -- TODO confirm with the caller.
        """
        # Note that this bitmask is a bitset, not bool
        bitmask = self.matcher.get_next_token_bitmask()
        # Mask the tokens that are not allowed
        vocab_mask[
            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
        ] = 1

    def copy(self):
        """Return a new ``XGrammarGrammar`` with a new matcher over ``self.ctx``.

        The new matcher is built from the compiled grammar, not from
        ``self.matcher``, so tokens accepted by this instance are not carried
        over (assuming a new GrammarMatcher starts at the grammar's initial
        state).
        """
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            mask_vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
85
|
+
|
86
|
+
|
87
|
+
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that produces ``XGrammarGrammar`` objects via xgrammar.

    Only JSON-schema keys are compiled; ``"regex"`` keys and any other key
    type are rejected with ``ValueError``.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()
        # xgrammar compiler that turns grammar specs into CompiledGrammar objects.
        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
        self.vocab_size = vocab_size

    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
        """Compile the grammar named by ``key`` and wrap it with a new matcher.

        Args:
            key: ``(key_type, key_string)``. Only ``key_type == "json"`` is
                accepted; ``key_string`` is handed to xgrammar as a JSON schema.

        Raises:
            ValueError: for ``"regex"`` keys and for any unrecognized key type.
        """
        key_type, schema = key
        if key_type == "regex":
            raise ValueError("regex hasn't been supported by xgrammar yet")
        if key_type != "json":
            raise ValueError(f"Invalid key_type: {key_type}")

        compiled = self.grammar_cache.get_compiled_grammar_for_json_schema(schema)
        return XGrammarGrammar(
            GrammarMatcher(
                compiled,
                max_rollback_tokens=MAX_ROLLBACK_TOKENS,
                mask_vocab_size=self.vocab_size,
            ),
            self.vocab_size,
            compiled,
        )

    def reset(self):
        """Clear the xgrammar compiler cache.

        NOTE(review): this does not call ``super().reset()``; confirm whether
        ``BaseGrammarBackend`` keeps its own cache that should also be cleared.
        """
        self.grammar_cache.clear()
|
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
|
|
507
507
|
|
508
508
|
num_warps = 4
|
509
509
|
|
510
|
+
extra_kargs = {}
|
511
|
+
if is_hip():
|
512
|
+
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
513
|
+
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
514
|
+
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
515
|
+
|
510
516
|
_fwd_grouped_kernel_stage1[grid](
|
511
517
|
q,
|
512
518
|
k_buffer,
|
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
|
|
532
538
|
num_warps=num_warps,
|
533
539
|
num_stages=1,
|
534
540
|
Lk=Lk,
|
541
|
+
**extra_kargs,
|
535
542
|
)
|
536
543
|
|
537
544
|
|
@@ -25,6 +25,7 @@ import triton.language as tl
|
|
25
25
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
26
26
|
context_attention_fwd,
|
27
27
|
)
|
28
|
+
from sglang.srt.utils import is_hip
|
28
29
|
|
29
30
|
is_cuda_available = torch.cuda.is_available()
|
30
31
|
if is_cuda_available:
|
@@ -311,6 +312,10 @@ def extend_attention_fwd(
|
|
311
312
|
num_warps = 4 if Lk <= 64 else 8
|
312
313
|
num_stages = 1
|
313
314
|
|
315
|
+
extra_kargs = {}
|
316
|
+
if is_hip():
|
317
|
+
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
318
|
+
|
314
319
|
_fwd_kernel[grid](
|
315
320
|
q_extend,
|
316
321
|
k_extend,
|
@@ -348,6 +353,7 @@ def extend_attention_fwd(
|
|
348
353
|
Lv=Lv,
|
349
354
|
num_warps=num_warps,
|
350
355
|
num_stages=num_stages,
|
356
|
+
**extra_kargs,
|
351
357
|
)
|
352
358
|
|
353
359
|
|
@@ -54,6 +54,7 @@ def fused_moe_kernel(
|
|
54
54
|
top_k: tl.constexpr,
|
55
55
|
compute_type: tl.constexpr,
|
56
56
|
use_fp8: tl.constexpr,
|
57
|
+
even_Ks: tl.constexpr,
|
57
58
|
):
|
58
59
|
"""
|
59
60
|
Implements the fused computation for a Mixture of Experts (MOE) using
|
@@ -130,16 +131,24 @@ def fused_moe_kernel(
|
|
130
131
|
# of fp32 values for higher accuracy.
|
131
132
|
# `accumulator` will be converted back to fp16 after the loop.
|
132
133
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
133
|
-
|
134
134
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
135
135
|
# Load the next block of A and B, generate a mask by checking the
|
136
136
|
# K dimension.
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
137
|
+
if even_Ks:
|
138
|
+
a = tl.load(
|
139
|
+
a_ptrs,
|
140
|
+
mask=token_mask[:, None],
|
141
|
+
other=0.0,
|
142
|
+
)
|
143
|
+
b = tl.load(b_ptrs)
|
144
|
+
else:
|
145
|
+
a = tl.load(
|
146
|
+
a_ptrs,
|
147
|
+
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
148
|
+
other=0.0,
|
149
|
+
)
|
150
|
+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
151
|
+
|
143
152
|
# We accumulate along the K dimension.
|
144
153
|
if use_fp8:
|
145
154
|
accumulator = tl.dot(a, b, acc=accumulator)
|
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
|
|
253
262
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
254
263
|
)
|
255
264
|
|
265
|
+
K = B.shape[2] - padding_size
|
266
|
+
if K % config["BLOCK_SIZE_K"] == 0:
|
267
|
+
even_ks = True
|
268
|
+
else:
|
269
|
+
even_ks = False
|
270
|
+
|
256
271
|
fused_moe_kernel[grid](
|
257
272
|
A,
|
258
273
|
B,
|
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
|
|
278
293
|
top_k=top_k,
|
279
294
|
compute_type=compute_type,
|
280
295
|
use_fp8=use_fp8,
|
296
|
+
even_Ks=even_ks,
|
281
297
|
**config,
|
282
298
|
)
|
283
299
|
|
@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
|
|
122
122
|
"""
|
123
123
|
raise NotImplementedError
|
124
124
|
|
125
|
-
|
126
|
-
|
125
|
+
|
126
|
+
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
    """Report whether ``method_class`` supplies its own ``embedding``.

    Not every quantization method implements embedding support. A method
    counts as implementing it only if it has an ``embedding`` attribute that
    is a different object from ``QuantizeMethodBase.embedding``. Both lookups
    use ``inspect.getattr_static``, so descriptors and ``__getattr__`` hooks
    are not triggered.
    """
    inherited = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    own = inspect.getattr_static(method_class, "embedding", None)
    if own is None:
        return False
    return own is not inherited
|