sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -1,17 +1,43 @@
-
+"""
+Faster constrained decoding.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+from collections import defaultdict
 
-
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
 
             fsm_info: FSMInfo = regex_fsm.fsm_info
 
@@ -21,40 +47,93 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)
 
             transitions = fsm_info.transitions
-
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
 
             for (state, id_), next_state in transitions.items():
-                if
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) > 1:
+                        # FIXME: This logic is due to the leading \x00
+                        # https://github.com/outlines-dev/outlines/pull/930
+                        byte_ = int(symbols[0][1:], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e
 
             return state_to_jump_forward
 
         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
 
-    def
-
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
 
-
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None
 
-
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-
-
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
            state = next_state
-
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
 
 
 class JumpForwardCache(BaseCache):
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)
 
 
-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.
-
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
 
 
 if __name__ == "__main__":
-
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
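The rewritten JumpForwardMap stores a JumpEdge per deterministic FSM state and answers both symbol-level and byte-level jump-forward queries. A minimal usage sketch, assuming sglang 0.1.18 and the outlines/interegular dependencies imported above (the regex is only an example):

from sglang.srt.constrained.jump_forward import IP_REGEX, JumpForwardMap

jump_map = JumpForwardMap(r"The DNS server address is " + IP_REGEX)

for state in list(jump_map.state_to_jump_forward):
    if jump_map.is_jump_forward_symbol_state(state):
        # Deterministic run of single-byte symbols: append them without decoding.
        text, next_state = jump_map.jump_forward_symbol(state)
        print(state, "->", next_state, repr(text))
    else:
        # Byte-level run, e.g. inside multi-byte UTF-8 characters.
        edges = jump_map.jump_forward_byte(state)  # list of (byte, next_state)
        print(state, "->", edges[-1][1], [hex(b) for b, _ in edges])

The byte-level path is what the new make_byte_level_fsm(..., keep_utf8=True) construction enables; it lets jump-forward cross multi-byte characters such as the Chinese test strings exercised in test_main.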
sglang/srt/conversation.py
CHANGED
sglang/srt/flush_cache.py
CHANGED
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -1,10 +1,10 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-import
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -88,6 +88,9 @@ def get_tokenizer(
     if tokenizer_name.endswith(".json"):
         return TiktokenTokenizer(tokenizer_name)
 
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -179,6 +182,8 @@ def get_processor(
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
         import tiktoken
+        from jinja2 import Template
+
         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
 
         # Read JSON
@@ -190,7 +195,8 @@ class TiktokenTokenizer:
             bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
         }
         special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
         }
         assert tok_dict["word_split"] == "V1"
 
@@ -202,7 +208,10 @@ class TiktokenTokenizer:
         }
         if "default_allowed_special" in tok_dict:
             default_allowed_special = set(
-                [
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
             )
         else:
            default_allowed_special = None
@@ -211,25 +220,35 @@ class TiktokenTokenizer:
 
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
 
         def encode_patched(
             self,
             text: str,
             *,
-            allowed_special: Union[
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         ) -> list[int]:
             if isinstance(allowed_special, set):
                 allowed_special |= self._default_allowed_special
             return tiktoken.Encoding.encode(
-                self,
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
             )
+
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
         self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
 
     def encode(self, x, add_special_tokens=False):
         return self.tokenizer.encode(x)
@@ -237,10 +256,46 @@ class TiktokenTokenizer:
     def decode(self, x):
         return self.tokenizer.decode(x)
 
-    def batch_decode(
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
         if isinstance(batch[0], int):
             batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)
 
-    def
-
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
sglang/srt/layers/fused_moe.py
CHANGED
@@ -12,7 +12,6 @@ import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -310,92 +309,110 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
-def
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    if dtype == "float8":
+        config = {
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 256,
+            'BLOCK_SIZE_K': 128,
+            'GROUP_SIZE_M': 32,
+            "num_warps": 8,
+            "num_stages": 4
+        }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 128,
+                'GROUP_SIZE_M': 1,
+                "num_warps": 4,
+                "num_stages": 4
+            }
+    else:
+        config = {
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 64,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 16,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 64,
+                'GROUP_SIZE_M': 1
+            }
+    return config
+
+
+def fused_topk(
     hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-
-
-
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
 
-
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
+    M, _ = hidden_states.shape
 
-
-
-
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
+
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
     if override_config:
         config = override_config
     else:
@@ -409,24 +426,9 @@ def fused_moe(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Else use the default config
-        config =
-
-
-            "BLOCK_SIZE_K": 128,
-            "GROUP_SIZE_M": 1,
-            "num_warps": 4,
-            "num_stages": 4
-        }
-
-        if M <= E:
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 16,
-                "num_warps": 8,
-                "num_stages": 4
-            }
+        config = get_default_config(M, E, N, w1.shape[2],
+                                    topk_ids.shape[1],
+                                    "float8" if use_fp8 else None)
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
@@ -482,4 +484,99 @@ def fused_moe(
                   dim=1,
                   out=hidden_states)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+                     dim=1)
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    if hasattr(ops, "topk_softmax"):
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+    else:
+        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
+                                                   renormalize)
+
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         inplace=inplace,
+                         override_config=override_config,
+                         use_fp8=use_fp8,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
+                         a1_scale=a1_scale,
+                         a2_scale=a2_scale)
+
+
+def fused_topk_v0_4_3(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    import vllm._moe_C as moe_kernels
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids