sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_offline_throughput.py +309 -0
  2. sglang/bench_serving.py +148 -24
  3. sglang/srt/configs/model_config.py +5 -2
  4. sglang/srt/constrained/__init__.py +2 -66
  5. sglang/srt/constrained/base_grammar_backend.py +73 -0
  6. sglang/srt/constrained/outlines_backend.py +165 -0
  7. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  8. sglang/srt/constrained/xgrammar_backend.py +150 -0
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  11. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  12. sglang/srt/layers/fused_moe/patch.py +4 -2
  13. sglang/srt/layers/quantization/base_config.py +4 -6
  14. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  15. sglang/srt/managers/detokenizer_manager.py +0 -14
  16. sglang/srt/managers/io_struct.py +5 -3
  17. sglang/srt/managers/schedule_batch.py +14 -20
  18. sglang/srt/managers/scheduler.py +159 -96
  19. sglang/srt/managers/tokenizer_manager.py +81 -17
  20. sglang/srt/metrics/collector.py +211 -0
  21. sglang/srt/metrics/func_timer.py +108 -0
  22. sglang/srt/mm_utils.py +1 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  24. sglang/srt/model_executor/forward_batch_info.py +7 -3
  25. sglang/srt/model_executor/model_runner.py +6 -2
  26. sglang/srt/models/gemma2_reward.py +69 -0
  27. sglang/srt/models/gpt2.py +31 -37
  28. sglang/srt/models/internlm2_reward.py +62 -0
  29. sglang/srt/models/llama.py +11 -6
  30. sglang/srt/models/llama_reward.py +5 -26
  31. sglang/srt/models/qwen2_vl.py +5 -7
  32. sglang/srt/openai_api/adapter.py +11 -4
  33. sglang/srt/openai_api/protocol.py +29 -26
  34. sglang/srt/sampling/sampling_batch_info.py +2 -3
  35. sglang/srt/sampling/sampling_params.py +2 -16
  36. sglang/srt/server.py +60 -17
  37. sglang/srt/server_args.py +66 -25
  38. sglang/srt/utils.py +120 -0
  39. sglang/test/simple_eval_common.py +1 -1
  40. sglang/test/simple_eval_humaneval.py +2 -2
  41. sglang/test/simple_eval_mgsm.py +2 -2
  42. sglang/test/test_utils.py +21 -7
  43. sglang/utils.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
  46. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
  47. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
  48. sglang/srt/constrained/base_tool_cache.py +0 -65
  49. sglang/srt/constrained/bnf_cache.py +0 -61
  50. sglang/srt/constrained/fsm_cache.py +0 -95
  51. sglang/srt/constrained/grammar.py +0 -190
  52. sglang/srt/constrained/jump_forward.py +0 -203
  53. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
  54. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/base_grammar_backend.py
@@ -0,0 +1,73 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """The baseclass of a backend for grammar-guided constrained decoding."""
+
+ from concurrent.futures import Future, ThreadPoolExecutor
+ from dataclasses import dataclass
+ from threading import Event, Lock
+ from typing import Any, Optional, Tuple
+
+
+ @dataclass
+ class CacheEntry:
+     value: Any
+     event: Event
+
+
+ class BaseGrammarObject:
+     pass
+
+
+ class BaseGrammarBackend:
+     def __init__(self):
+         self.executor = ThreadPoolExecutor()
+         self.cache = {}
+         self.cache_lock = Lock()
+
+     def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
+         with self.cache_lock:
+             if key in self.cache:
+                 cache_hit = True
+                 entry = self.cache[key]
+             else:
+                 cache_hit = False
+                 entry = CacheEntry(None, Event())
+                 self.cache[key] = entry
+
+         if cache_hit:
+             entry.event.wait()
+         else:
+             entry.value = self.init_value_impl(key)
+             entry.event.set()
+         return entry.value.copy() if entry.value else None
+
+     def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
+         raise NotImplementedError()
+
+     def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+         with self.cache_lock:
+             entry = self.cache.get(key)
+             if not entry or not entry.event.is_set():
+                 return None
+             val = self.cache[key].value
+             return val.copy() if val else None
+
+     def get_future_value(self, key: Tuple[str, str]) -> Future:
+         return self.executor.submit(self.init_value, key)
+
+     def reset(self):
+         with self.cache_lock:
+             self.cache.clear()
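
For orientation, here is a minimal sketch (not part of this diff; the subclass and key are hypothetical) of how the cache/Future contract above is meant to be used: a concrete backend overrides init_value_impl, get_future_value schedules compilation on the thread pool, and concurrent requests for the same key block on the entry's Event until the first compilation finishes.

from sglang.srt.constrained.base_grammar_backend import BaseGrammarBackend, BaseGrammarObject

class ToyGrammar(BaseGrammarObject):
    def __init__(self, spec: str):
        self.spec = spec

    def copy(self):
        # init_value()/get_cached_value() hand out copies, so a copy() method is expected.
        return ToyGrammar(self.spec)

class ToyBackend(BaseGrammarBackend):
    def init_value_impl(self, key):
        key_type, key_string = key  # e.g. ("regex", "[0-9]+")
        return ToyGrammar(f"{key_type}:{key_string}")

backend = ToyBackend()
future = backend.get_future_value(("regex", "[0-9]+"))  # compiled once, off the hot path
grammar = future.result()  # each caller gets its own copy of the cached object
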
sglang/srt/constrained/outlines_backend.py
@@ -0,0 +1,165 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """Constrained decoding with outlines backend."""
+
+ import json
+ import logging
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import interegular
+ import torch
+ from outlines.fsm.guide import RegexGuide
+ from outlines.fsm.json_schema import build_regex_from_schema
+ from outlines.models.transformers import TransformerTokenizer
+ from pydantic import BaseModel
+
+ from sglang.srt.constrained.base_grammar_backend import (
+     BaseGrammarBackend,
+     BaseGrammarObject,
+ )
+ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+
+ logger = logging.getLogger(__name__)
+
+
+ class OutlinesGrammar(BaseGrammarObject):
+     def __init__(
+         self,
+         guide: RegexGuide,
+         jump_forward_map: Union[OutlinesJumpForwardMap, None],
+     ) -> None:
+         self.guide = guide
+         self.jump_forward_map = jump_forward_map
+         self.state = 0
+
+     def accept_token(self, token: int):
+         self.state = self.guide.get_next_state(self.state, token)
+
+     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
+         if not self.jump_forward_map:
+             return None
+
+         jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
+         if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
+             return None
+
+         # preprocess the jump forward string
+         suffix_bytes = []
+         continuation_range = range(0x80, 0xC0)
+         cur_state = self.state
+         while (
+             len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
+         ):
+             # continuation bytes
+             byte_edge = jump_forward_bytes.pop(0)
+             suffix_bytes.append(byte_edge[0])
+             cur_state = byte_edge[1]
+
+         suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+         suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+         return suffix_ids, cur_state
+
+     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+         _, cur_state = helper
+         return self.jump_forward_map.jump_forward_symbol(cur_state)
+
+     def jump_and_retokenize(
+         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+     ):
+         self.state = next_state
+
+     def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+         vocab_mask.fill_(1)
+         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+
+     def copy(self):
+         return OutlinesGrammar(self.guide, self.jump_forward_map)
+
+
+ class OutlinesGrammarBackend(BaseGrammarBackend):
+     def __init__(
+         self,
+         tokenizer,
+         whitespace_pattern: bool,
+         allow_jump_forward: bool,
+     ):
+         super().__init__()
+
+         try:
+             self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+         except AttributeError:
+             # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+             origin_pad_token_id = tokenizer.pad_token_id
+
+             def fset(self, value):
+                 self._value = value
+
+             type(tokenizer).pad_token_id = property(
+                 fget=type(tokenizer).pad_token_id.fget, fset=fset
+             )
+             self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+             self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+             self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+             self.outlines_tokenizer.pad_token = (
+                 self.outlines_tokenizer.tokenizer.pad_token
+             )
+             self.outlines_tokenizer.vocabulary = (
+                 self.outlines_tokenizer.tokenizer.get_vocab()
+             )
+         self.allow_jump_forward = allow_jump_forward
+         self.whitespace_pattern = whitespace_pattern
+
+     def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
+         key_type, key_string = key
+         if key_type == "json":
+             try:
+                 regex = build_regex_from_object(
+                     key_string,
+                     whitespace_pattern=self.whitespace_pattern,
+                 )
+             except (NotImplementedError, json.decoder.JSONDecodeError) as e:
+                 logger.warning(
+                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
+                 )
+                 return None
+         elif key_type == "regex":
+             regex = key_string
+         else:
+             raise ValueError(f"Invalid key_type: {key_type}")
+
+         try:
+             guide = RegexGuide(regex, self.outlines_tokenizer)
+         except interegular.patterns.InvalidSyntax as e:
+             logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
+             return None
+
+         if self.allow_jump_forward:
+             jump_forward_map = OutlinesJumpForwardMap(regex)
+         else:
+             jump_forward_map = None
+         return OutlinesGrammar(guide, jump_forward_map)
+
+
+ def build_regex_from_object(
+     object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+ ):
+     if isinstance(object, type(BaseModel)):
+         schema = json.dumps(object.model_json_schema())
+     elif isinstance(object, Dict):
+         schema = json.dumps(object)
+     else:
+         schema = object
+     return build_regex_from_schema(schema, whitespace_pattern)
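
As a usage note, a small sketch (assuming a compatible outlines version is installed and the module imports as listed above; the schema below is made up) of the module-level helper defined at the end of this file, which accepts a Pydantic model class, a dict, or a raw JSON-schema string and returns the regex that RegexGuide will then enforce token by token:

from sglang.srt.constrained.outlines_backend import build_regex_from_object

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
regex = build_regex_from_object(schema)  # dicts are json.dumps'ed, then compiled to a regex
print(regex[:80])  # the generated pattern can be long
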
sglang/srt/constrained/outlines_jump_forward.py
@@ -0,0 +1,182 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """
+ Faster constrained decoding with jump forward decoding / compressed finite state machine.
+ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+ """
+
+ import dataclasses
+ import logging
+ from collections import defaultdict
+
+ import interegular
+ from interegular import InvalidSyntax
+ from outlines.caching import cache as disk_cache
+ from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+
+ IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclasses.dataclass
+ class JumpEdge:
+     symbol: str = None
+     symbol_next_state: int = None
+     byte: int = None
+     byte_next_state: int = None
+
+
+ @disk_cache()
+ def init_state_to_jump_forward(regex_string):
+     try:
+         regex_pattern = interegular.parse_pattern(regex_string)
+     except InvalidSyntax as e:
+         logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+         return
+
+     byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+     regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+
+     fsm_info: FSMInfo = regex_fsm.fsm_info
+
+     symbol_to_id = fsm_info.alphabet_symbol_mapping
+     id_to_symbol = {}
+     for symbol, id_ in symbol_to_id.items():
+         id_to_symbol.setdefault(id_, []).append(symbol)
+
+     transitions = fsm_info.transitions
+
+     outgoings_ct = defaultdict(int)
+     # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
+     for s in fsm_info.finals:
+         outgoings_ct[s] = 1
+
+     state_to_jump_forward = {}
+     for (state, id_), next_state in transitions.items():
+         if id_ == fsm_info.alphabet_anything_value:
+             # Arbitrarily symbol cannot be recognized as jump forward
+             continue
+
+         symbols = id_to_symbol[id_]
+         for c in symbols:
+             if len(c) > 1:
+                 # Skip byte level transitions like c = "5E"
+                 continue
+
+             outgoings_ct[state] += 1
+             if outgoings_ct[state] > 1:
+                 if state in state_to_jump_forward:
+                     del state_to_jump_forward[state]
+                 break
+
+             state_to_jump_forward[state] = JumpEdge(
+                 symbol=c,
+                 symbol_next_state=next_state,
+             )
+
+     # Process the byte level jump forward
+     outgoings_ct = defaultdict(int)
+     for s in fsm_info.finals:
+         outgoings_ct[s] = 1
+
+     for (state, id_), next_state in transitions.items():
+         if id_ == fsm_info.alphabet_anything_value:
+             continue
+         symbols = id_to_symbol[id_]
+         for c in symbols:
+             byte_ = None
+             if len(c) == 1 and ord(c) < 0x80:
+                 # ASCII character
+                 byte_ = ord(c)
+             elif len(c) > 1:
+                 # FIXME: This logic is due to the leading \x00
+                 # https://github.com/outlines-dev/outlines/pull/930
+                 byte_ = int(symbols[0][1:], 16)
+
+             if byte_ is not None:
+                 outgoings_ct[state] += 1
+                 if outgoings_ct[state] > 1:
+                     if state in state_to_jump_forward:
+                         del state_to_jump_forward[state]
+                     break
+                 e = state_to_jump_forward.get(state, JumpEdge())
+                 e.byte = byte_
+                 e.byte_next_state = next_state
+                 state_to_jump_forward[state] = e
+
+     return state_to_jump_forward
+
+
+ class OutlinesJumpForwardMap:
+     def __init__(self, regex_string):
+         self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
+
+     def jump_forward_symbol(self, state):
+         jump_forward_str = ""
+         next_state = state
+         while state in self.state_to_jump_forward:
+             e = self.state_to_jump_forward[state]
+             if e.symbol is None:
+                 break
+             jump_forward_str += e.symbol
+             next_state = e.symbol_next_state
+             state = next_state
+
+         return jump_forward_str, next_state
+
+     def jump_forward_byte(self, state):
+         if state not in self.state_to_jump_forward:
+             return None
+
+         jump_forward_bytes = []
+         next_state = None
+         while state in self.state_to_jump_forward:
+             e = self.state_to_jump_forward[state]
+             assert e.byte is not None and e.byte_next_state is not None
+             jump_forward_bytes.append((e.byte, e.byte_next_state))
+             next_state = e.byte_next_state
+             state = next_state
+
+         return jump_forward_bytes
+
+     def is_jump_forward_symbol_state(self, state):
+         return (
+             state in self.state_to_jump_forward
+             and self.state_to_jump_forward[state].symbol is not None
+         )
+
+
+ def test_main(regex_string):
+     jump_forward_map = OutlinesJumpForwardMap(regex_string)
+     for state, e in jump_forward_map.state_to_jump_forward.items():
+         if e.symbol is not None:
+             jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+             print(f"{state} -> {next_state}", jump_forward_str)
+         bytes_ = jump_forward_map.jump_forward_byte(state)
+         print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
+
+
+ if __name__ == "__main__":
+     import outlines
+
+     outlines.caching.clear_cache()
+     test_main(r"The google's DNS sever address is " + IP_REGEX)
+     test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+     # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+     # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
+
+     test_main(r"[-+]?[0-9]+[ ]*")
1
+ """
2
+ Copyright 2023-2024 SGLang Team
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """Constrained decoding with xgrammar backend."""
17
+
18
+ import logging
19
+ from typing import List, Tuple
20
+
21
+ import torch
22
+
23
+ try:
24
+ from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
25
+
26
+ import_error = None
27
+ except ImportError as e:
28
+ CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
29
+ ImportError
30
+ )
31
+ import_error = e
32
+
33
+ from sglang.srt.constrained.base_grammar_backend import (
34
+ BaseGrammarBackend,
35
+ BaseGrammarObject,
36
+ )
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ MAX_ROLLBACK_TOKENS = 10
42
+
43
+
44
+ class XGrammarGrammar(BaseGrammarObject):
45
+
46
+ def __init__(
47
+ self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
48
+ ) -> None:
49
+ self.matcher = matcher
50
+ self.vocab_size = vocab_size
51
+ self.ctx = ctx
52
+
53
+ def accept_token(self, token: int):
54
+ assert self.matcher.accept_token(token)
55
+
56
+ def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
57
+ s = self.matcher.find_jump_forward_string()
58
+ if s:
59
+ return [], s
60
+ return None
61
+
62
+ def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
63
+ _, data = helper
64
+ return data, -1
65
+
66
+ def jump_and_retokenize(
67
+ self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
68
+ ):
69
+ k = 0
70
+ for i, old_id in enumerate(old_output_ids):
71
+ if old_id == new_output_ids[i]:
72
+ k = i + 1
73
+ else:
74
+ break
75
+
76
+ # rollback to the last token that is the same
77
+ if k < len(old_output_ids):
78
+ self.matcher.rollback(len(old_output_ids) - k)
79
+
80
+ for i in range(k, len(new_output_ids)):
81
+ assert self.matcher.accept_token(new_output_ids[i])
82
+
83
+ def fill_vocab_mask(self, vocab_mask: torch.Tensor):
84
+ # Note that this bitmask is a bitset, not bool
85
+ bitmask = self.matcher.get_next_token_bitmask()
86
+ # Mask the tokens that are not allowed
87
+ vocab_mask[
88
+ self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
89
+ ] = 1
90
+
91
+ def copy(self):
92
+ matcher = GrammarMatcher(
93
+ self.ctx,
94
+ max_rollback_tokens=MAX_ROLLBACK_TOKENS,
95
+ mask_vocab_size=self.vocab_size,
96
+ )
97
+ return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
98
+
99
+
100
+ class XGrammarGrammarBackend(BaseGrammarBackend):
101
+ def __init__(
102
+ self,
103
+ tokenizer,
104
+ vocab_size: int,
105
+ ):
106
+ super().__init__()
107
+
108
+ if import_error:
109
+ logger.warning(
110
+ f"Ignore import error for the grammar backend: {import_error}"
111
+ )
112
+ self.grammar_cache = None
113
+ return
114
+
115
+ self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
116
+ self.vocab_size = vocab_size
117
+
118
+ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
119
+ if import_error:
120
+ raise import_error
121
+
122
+ key_type, key_string = key
123
+ if key_type == "json":
124
+ try:
125
+ ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
126
+ key_string
127
+ )
128
+ except RuntimeError as e:
129
+ logging.warning(
130
+ f"Skip invalid json_schema: json_schema={key_string}, {e=}"
131
+ )
132
+ return None
133
+ elif key_type == "regex":
134
+ logger.warning(
135
+ "regex hasn't been supported by xgrammar yet. This is skipped."
136
+ )
137
+ return None
138
+ else:
139
+ raise ValueError(f"Invalid key_type: {key_type}")
140
+
141
+ matcher = GrammarMatcher(
142
+ ctx,
143
+ max_rollback_tokens=MAX_ROLLBACK_TOKENS,
144
+ mask_vocab_size=self.vocab_size,
145
+ )
146
+ return XGrammarGrammar(matcher, self.vocab_size, ctx)
147
+
148
+ def reset(self):
149
+ if self.grammar_cache:
150
+ self.grammar_cache.clear()
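
For context, a minimal sketch (plain PyTorch, not sglang's actual sampling path) of how a mask filled by fill_vocab_mask above would typically be consumed: positions marked 1 (rejected by the grammar) are pushed to -inf before picking the next token.

import torch

vocab_size = 8
logits = torch.randn(vocab_size)
vocab_mask = torch.zeros(vocab_size, dtype=torch.bool)
vocab_mask[[1, 3, 5]] = True  # pretend the matcher rejected these token ids
logits = logits.masked_fill(vocab_mask, float("-inf"))
next_token = torch.argmax(logits).item()  # greedy pick among grammar-legal tokens
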
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(

      num_warps = 4

+     extra_kargs = {}
+     if is_hip():
+         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
      _fwd_grouped_kernel_stage1[grid](
          q,
          k_buffer,
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
          num_warps=num_warps,
          num_stages=1,
          Lk=Lk,
+         **extra_kargs,
      )


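Both this kernel and the extend-attention kernel below gate AMD-specific Triton launch options behind is_hip(). The helper itself lives in sglang/srt/utils.py and is not shown in this diff; the sketch below is a plausible check, stated as an assumption rather than the verbatim implementation:

import torch

def is_hip() -> bool:
    # True when PyTorch was built against ROCm/HIP rather than CUDA.
    return torch.version.hip is not None

# The AMD-only Triton launch knobs used above; empty elsewhere so the call site stays unchanged.
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} if is_hip() else {}
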
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -25,6 +25,7 @@ import triton.language as tl
  from sglang.srt.layers.attention.triton_ops.prefill_attention import (
      context_attention_fwd,
  )
+ from sglang.srt.utils import is_hip

  is_cuda_available = torch.cuda.is_available()
  if is_cuda_available:
@@ -311,6 +312,10 @@ def extend_attention_fwd(
      num_warps = 4 if Lk <= 64 else 8
      num_stages = 1

+     extra_kargs = {}
+     if is_hip():
+         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
      _fwd_kernel[grid](
          q_extend,
          k_extend,
@@ -348,6 +353,7 @@ def extend_attention_fwd(
          Lv=Lv,
          num_warps=num_warps,
          num_stages=num_stages,
+         **extra_kargs,
      )


sglang/srt/layers/fused_moe/fused_moe.py
@@ -54,6 +54,7 @@ def fused_moe_kernel(
      top_k: tl.constexpr,
      compute_type: tl.constexpr,
      use_fp8: tl.constexpr,
+     even_Ks: tl.constexpr,
  ):
      """
      Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
      # of fp32 values for higher accuracy.
      # `accumulator` will be converted back to fp16 after the loop.
      accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
      for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
          # Load the next block of A and B, generate a mask by checking the
          # K dimension.
-         a = tl.load(
-             a_ptrs,
-             mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-             other=0.0,
-         )
-         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+         if even_Ks:
+             a = tl.load(
+                 a_ptrs,
+                 mask=token_mask[:, None],
+                 other=0.0,
+             )
+             b = tl.load(b_ptrs)
+         else:
+             a = tl.load(
+                 a_ptrs,
+                 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                 other=0.0,
+             )
+             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
          # We accumulate along the K dimension.
          if use_fp8:
              accumulator = tl.dot(a, b, acc=accumulator)
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
          * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
      )

+     K = B.shape[2] - padding_size
+     if K % config["BLOCK_SIZE_K"] == 0:
+         even_ks = True
+     else:
+         even_ks = False
+
      fused_moe_kernel[grid](
          A,
          B,
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
          top_k=top_k,
          compute_type=compute_type,
          use_fp8=use_fp8,
+         even_Ks=even_ks,
          **config,
      )

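The apparent intent of the even_Ks flag: when the (padded) K dimension splits evenly into BLOCK_SIZE_K tiles, the inner loop never reads past the end of the matrix, so the per-tile boundary mask on K can be dropped. A toy illustration of the check (hypothetical numbers, mirroring the even_ks computation above):

def needs_k_boundary_mask(k: int, block_size_k: int) -> bool:
    # Inverse of the even_ks condition in invoke_fused_moe_kernel.
    return k % block_size_k != 0

assert needs_k_boundary_mask(4096, 128) is False  # even split: take the unmasked fast path
assert needs_k_boundary_mask(4100, 128) is True   # ragged tail: keep the masked loads
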
sglang/srt/layers/fused_moe/patch.py
@@ -1,4 +1,4 @@
- from typing import Optional
+ from typing import Callable, Optional

  import torch
  from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
      renormalize: bool,
      topk_group: Optional[int] = None,
      num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
  ) -> torch.Tensor:
+     assert custom_routing_function is None
      topk_weights, topk_ids = select_experts_native(
          hidden_states=x,
          router_logits=router_logits,
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
      x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
      x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
      expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
+     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
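
The last change casts topk_weights to the experts' dtype before the final einsum. A small sketch of why (assuming a reduced-precision model): the two einsum operands generally need to share a floating dtype, while routing weights typically come out of a float32 softmax.

import torch

expert_outs = torch.randn(4, 2, 8, dtype=torch.bfloat16)  # per-expert outputs in reduced precision
topk_weights = torch.rand(4, 2)                            # routing weights in fp32
# Mixing dtypes here can raise; the cast added above keeps them consistent.
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
print(out.shape)  # torch.Size([4, 8])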