sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/jump_forward.py CHANGED
@@ -1,17 +1,43 @@
-import interegular
+"""
+Faster constrained decoding.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+from collections import defaultdict
 
-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
 
             fsm_info: FSMInfo = regex_fsm.fsm_info
 
@@ -21,40 +47,93 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)
 
             transitions = fsm_info.transitions
-            dirty_states = set()
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
 
             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) > 1:
+                        # FIXME: This logic is due to the leading \x00
+                        # https://github.com/outlines-dev/outlines/pull/930
+                        byte_ = int(symbols[0][1:], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e
 
             return state_to_jump_forward
 
         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
 
-    def valid_states(self):
-        return self.state_to_jump_forward.keys()
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
 
-    def jump_forward(self, state):
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None
 
-        jump_forward_str = ""
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
 
 
 class JumpForwardCache(BaseCache):
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)
 
 
-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
 
 
 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
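The new map keeps two kinds of forced transitions per FSM state: a symbol edge for plain ASCII runs and a byte edge for multi-byte UTF-8 sequences. A minimal sketch of how a decoding loop might consume it (the regex and the assumption that the FSM starts in state 0 are illustrative, not part of the diff):

from sglang.srt.constrained.jump_forward import IP_REGEX, JumpForwardMap

jump_map = JumpForwardMap(r"The IP address is " + IP_REGEX)

state = 0  # illustrative; the real interpreter tracks the live FSM state
if jump_map.is_jump_forward_symbol_state(state):
    # Forced ASCII run: extend the output without any model forward pass.
    forced_str, state = jump_map.jump_forward_symbol(state)
    print("forced:", forced_str)
else:
    # Byte-level fallback, e.g. midway through a multi-byte UTF-8 character.
    edges = jump_map.jump_forward_byte(state)  # list of (byte, next_state) or None
    if edges is not None:
        forced_bytes = bytes(b for b, _ in edges)
        state = edges[-1][1]
        print("forced:", forced_bytes)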
sglang/srt/conversation.py CHANGED
@@ -1,3 +1,5 @@
+"""Conversation templates."""
+
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
sglang/srt/flush_cache.py CHANGED
@@ -1,4 +1,6 @@
 """
+Flush the KV cache.
+
 Usage:
 python3 -m sglang.srt.flush_cache --url http://localhost:30000
 """
sglang/srt/hf_transformers_utils.py CHANGED
@@ -1,10 +1,10 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-import functools
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -88,6 +88,9 @@ def get_tokenizer(
     if tokenizer_name.endswith(".json"):
         return TiktokenTokenizer(tokenizer_name)
 
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -179,6 +182,8 @@ def get_processor(
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
         import tiktoken
+        from jinja2 import Template
+
         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
 
         # Read JSON
@@ -190,7 +195,8 @@ class TiktokenTokenizer:
             bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
         }
         special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
         }
         assert tok_dict["word_split"] == "V1"
 
@@ -202,7 +208,10 @@ class TiktokenTokenizer:
         }
         if "default_allowed_special" in tok_dict:
             default_allowed_special = set(
-                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
             )
         else:
             default_allowed_special = None
@@ -211,25 +220,35 @@ class TiktokenTokenizer:
 
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
 
         def encode_patched(
             self,
             text: str,
             *,
-            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
             disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         ) -> list[int]:
             if isinstance(allowed_special, set):
                 allowed_special |= self._default_allowed_special
             return tiktoken.Encoding.encode(
-                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
             )
+
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
         self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
 
     def encode(self, x, add_special_tokens=False):
         return self.tokenizer.encode(x)
@@ -237,10 +256,46 @@ class TiktokenTokenizer:
     def decode(self, x):
         return self.tokenizer.decode(x)
 
-    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
         if isinstance(batch[0], int):
             batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)
 
-    def convert_ids_to_tokens(self, index):
-        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
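Both new wrappers expose the same minimal HF-style surface (encode, decode, batch_decode, apply_chat_template), so callers do not need to know which backend is in use. A short sketch of that calling convention; the tokenizer path and messages are made up for illustration:

from sglang.srt.hf_transformers_utils import get_tokenizer

# Per the diff above, a ".model" path routes to SentencePieceTokenizer
# and a ".json" path to TiktokenTokenizer.
tokenizer = get_tokenizer("/path/to/tokenizer.model")  # hypothetical path

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
input_ids = tokenizer.encode(prompt)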
sglang/srt/layers/fused_moe.py CHANGED
@@ -12,7 +12,6 @@ import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -310,92 +309,110 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
-def fused_moe(
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    if dtype == "float8":
+        config = {
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 256,
+            'BLOCK_SIZE_K': 128,
+            'GROUP_SIZE_M': 32,
+            "num_warps": 8,
+            "num_stages": 4
+        }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 128,
+                'GROUP_SIZE_M': 1,
+                "num_warps": 4,
+                "num_stages": 4
+            }
+    else:
+        config = {
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 64,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 16,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 64,
+                'GROUP_SIZE_M': 1
+            }
+    return config
+
+
+def fused_topk(
     hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
 
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-      (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-      Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-      for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-      products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-      w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-      w2.
+    M, _ = hidden_states.shape
 
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
+
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
     if override_config:
         config = override_config
     else:
@@ -409,24 +426,9 @@ def fused_moe(
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
-                "num_stages": 4
-            }
-
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 128,
-                    "BLOCK_SIZE_N": 256,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 16,
-                    "num_warps": 8,
-                    "num_stages": 4
-                }
+            config = get_default_config(M, E, N, w1.shape[2],
+                                        topk_ids.shape[1],
+                                        "float8" if use_fp8 else None)
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
@@ -482,4 +484,99 @@ def fused_moe(
         dim=1,
         out=hidden_states)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+                     dim=1)
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+      Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+      for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+      products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+      w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+      w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    if hasattr(ops, "topk_softmax"):
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+    else:
+        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
+                                                   renormalize)
+
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         inplace=inplace,
+                         override_config=override_config,
+                         use_fp8=use_fp8,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
+                         a1_scale=a1_scale,
+                         a2_scale=a2_scale)
+
+
+def fused_topk_v0_4_3(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    import vllm._moe_C as moe_kernels
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
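The refactor splits routing (fused_topk) from the expert GEMMs (fused_experts), keeping fused_moe as the combined entry point; fused_topk_v0_4_3 is a fallback for older vllm builds that lack ops.topk_softmax. A minimal sketch of driving the split API directly; the shapes are illustrative, and a CUDA device with triton installed is assumed:

import torch
from sglang.srt.layers.fused_moe import fused_experts, fused_topk

M, K, E, N, topk = 4, 128, 8, 256, 2  # tokens, hidden, experts, 2x intermediate, top-k
hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, N, K, dtype=torch.float16, device="cuda")       # gate/up projection
w2 = torch.randn(E, K, N // 2, dtype=torch.float16, device="cuda")  # down projection
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

# Route each token to its top-k experts, then run the fused expert kernels.
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize=True)
out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)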