sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/grammar.py
@@ -0,0 +1,190 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """Cache for the compressed finite state machine."""
+ import logging
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+
+ from sglang.srt.constrained import GrammarMatcher, RegexGuide
+ from sglang.srt.constrained.bnf_cache import BNFCache
+ from sglang.srt.constrained.fsm_cache import FSMCache
+ from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
+
+ # from sglang.srt.managers.schedule_batch import Req
+
+ logger = logging.getLogger(__name__)
+
+ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+ class XGrammarJump:
+     pass
+
+
+ class JumpHelper:
+     data: Union[List, str]
+     state: int
+     suffix_ids: List[int]
+
+     def __init__(
+         self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
+     ) -> None:
+         self.data = data
+         self.state = state
+         self.suffix_ids = suffix_ids
+
+     def can_jump(self):
+         return len(self.data) > 0
+
+
+ class Grammar:
+     grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
+     jump_map: Union[XGrammarJump, JumpForwardMap, None]
+
+     def __init__(
+         self,
+         grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
+         jump_map: Union[XGrammarJump, JumpForwardMap, None],
+     ) -> None:
+         self.grammar = grammar
+         self.jump_map = jump_map
+
+     def accept_token(self, token: int):
+         if isinstance(self.grammar, GrammarMatcher):
+             assert self.grammar.accept_token(token)
+         else:
+             guide, state = self.grammar
+             self.grammar = guide, guide.get_next_state(state, token)
+
+     def try_jump(self, tokenizer) -> JumpHelper:
+         if isinstance(self.jump_map, XGrammarJump):
+             assert isinstance(self.grammar, GrammarMatcher)
+             return JumpHelper(self.grammar.find_jump_forward_string())
+         elif isinstance(self.jump_map, JumpForwardMap):
+             assert isinstance(self.grammar, Tuple)
+
+             _, state = self.grammar
+             jump_forward_bytes = self.jump_map.jump_forward_byte(state)
+             if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
+                 return JumpHelper()  # can't jump
+
+             # preprocess the jump forward string
+             suffix_bytes = []
+             continuation_range = range(0x80, 0xC0)
+             cur_state = state
+             while (
+                 len(jump_forward_bytes)
+                 and jump_forward_bytes[0][0] in continuation_range
+             ):
+                 # continuation bytes
+                 byte_edge = jump_forward_bytes.pop(0)
+                 suffix_bytes.append(byte_edge[0])
+                 cur_state = byte_edge[1]
+
+             suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+             suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+             return JumpHelper(suffix_ids, cur_state, suffix_bytes)
+         else:
+             return JumpHelper()  # can't jump
+
+     def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
+         if isinstance(helper.data, str):
+             return helper.data, -1
+         else:
+             assert isinstance(self.jump_map, JumpForwardMap)
+             return self.jump_map.jump_forward_symbol(helper.state)
+
+     def jump_and_retokenize(
+         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+     ):
+         if isinstance(self.grammar, GrammarMatcher):
+             k = 0
+             for i, old_id in enumerate(old_output_ids):
+                 if old_id == new_output_ids[i]:
+                     k = i + 1
+                 else:
+                     break
+
+             # rollback to the last token that is the same
+             if k < len(old_output_ids):
+                 self.grammar.rollback(len(old_output_ids) - k)
+
+             for i in range(k, len(new_output_ids)):
+                 assert self.grammar.accept_token(new_output_ids[i])
+         else:
+             self.grammar = self.grammar[0], next_state
+
+     def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
+         if isinstance(self.grammar, GrammarMatcher):
+             # Note that this bitmask is a bitset, not bool
+             bitmask = self.grammar.find_next_token_bitmask()
+             # Mask the tokens that are not allowed
+             vocab_mask[
+                 self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
+             ] = 1
+         else:
+             guide, state = self.grammar
+             vocab_mask.fill_(1)
+             vocab_mask[guide.get_next_instruction(state).tokens] = 0
+
+
+ class GrammarCache:
+     grammar_cache: Union[BNFCache, FSMCache]
+     jump_cache: Union[XGrammarJump, JumpForwardCache, None]
+
+     def __init__(
+         self,
+         tokenizer_path,
+         tokenizer_args_dict,
+         skip_tokenizer_init=False,
+         whitespace_patterns=None,
+         backend=None,
+         allow_jump=False,
+     ):
+         if backend == "xgrammar":
+             self.grammar_cache = BNFCache(
+                 tokenizer_path=tokenizer_path,
+                 tokenizer_args_dict=tokenizer_args_dict,
+                 skip_tokenizer_init=skip_tokenizer_init,
+                 whitespace_patterns=whitespace_patterns,
+             )
+             self.jump_cache = XGrammarJump() if allow_jump else None
+         else:
+             assert backend == "outlines"
+             self.grammar_cache = FSMCache(
+                 tokenizer_path=tokenizer_path,
+                 tokenizer_args_dict=tokenizer_args_dict,
+                 skip_tokenizer_init=skip_tokenizer_init,
+                 constrained_json_whitespace_pattern=whitespace_patterns,
+                 enable=True,
+             )
+             self.jump_cache = JumpForwardCache() if allow_jump else None
+
+     def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+         if isinstance(self.grammar_cache, BNFCache):
+             assert not isinstance(self.jump_cache, JumpForwardCache)
+             return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
+         else:
+             jump_map = None
+             guide, regex = self.grammar_cache.query(key)
+             if isinstance(self.jump_cache, JumpForwardCache):
+                 jump_map = self.jump_cache.query(regex)
+             return Grammar((guide, 0), jump_map)
+
+     def reset(self):
+         if isinstance(self.grammar_cache, FSMCache):
+             self.grammar_cache.reset()
+         if isinstance(self.jump_cache, JumpForwardCache):
+             self.jump_cache.reset()
sglang/srt/hf_transformers_utils.py
@@ -81,26 +81,27 @@ def get_config(
  CONTEXT_LENGTH_KEYS = [
      "max_sequence_length",
      "seq_length",
-     "max_position_embeddings",
      "max_seq_len",
      "model_max_length",
+     "max_position_embeddings",
  ]


  def get_context_length(config):
      """Get the context length of a model from a huggingface model configs."""
-     rope_scaling = getattr(config, "rope_scaling", None)
+     text_config = config
+     rope_scaling = getattr(text_config, "rope_scaling", None)
      if rope_scaling:
-         rope_scaling_factor = config.rope_scaling.get("factor", 1)
+         rope_scaling_factor = rope_scaling.get("factor", 1)
          if "original_max_position_embeddings" in rope_scaling:
              rope_scaling_factor = 1
-         if config.rope_scaling.get("rope_type", None) == "llama3":
+         if rope_scaling.get("rope_type", None) == "llama3":
              rope_scaling_factor = 1
      else:
          rope_scaling_factor = 1

      for key in CONTEXT_LENGTH_KEYS:
-         val = getattr(config, key, None)
+         val = getattr(text_config, key, None)
          if val is not None:
              return int(rope_scaling_factor * val)
      return 2048
@@ -163,6 +164,8 @@ def get_tokenizer(
              "Using a slow tokenizer. This might cause a significant "
              "slowdown. Consider using a fast tokenizer instead."
          )
+
+     attach_additional_stop_token_ids(tokenizer)
      return tokenizer


@@ -181,4 +184,16 @@ def get_processor(
          tokenizer_revision=tokenizer_revision,
          **kwargs,
      )
+
+     attach_additional_stop_token_ids(processor.tokenizer)
      return processor
+
+
+ def attach_additional_stop_token_ids(tokenizer):
+     # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+     if "<|eom_id|>" in tokenizer.get_added_vocab():
+         tokenizer.additional_stop_token_ids = set(
+             [tokenizer.get_added_vocab()["<|eom_id|>"]]
+         )
+     else:
+         tokenizer.additional_stop_token_ids = None
sglang/srt/layers/attention/flashinfer_backend.py
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
      def update(
          self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
      ):
-         # Keep the signature for type checking, will be initialized during runtime
+         # Keep the signature for type checking. It will be assigned during runtime.
          raise NotImplementedError()

      def update_single_wrapper(
@@ -432,8 +432,8 @@
          kv_start_idx,
      ):
          bs = len(req_pool_indices)
+         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
          kv_indptr = kv_indptr[: bs + 1]
-         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
          kv_indices = torch.empty(
              paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
          )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
          self.update = self.update_single_wrapper

      def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
-         # Keep the signature for type checking, will be initialized during runtime
+         # Keep the signature for type checking. It will be assigned during runtime.
          raise NotImplementedError()

      def update_single_wrapper(
@@ -589,8 +589,8 @@
          use_ragged,
      ):
          bs = len(req_pool_indices)
+         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
          kv_indptr = kv_indptr[: bs + 1]
-         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
          kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
          create_flashinfer_kv_indices_triton[(bs,)](
              self.req_to_token,
@@ -602,8 +602,8 @@
              self.max_context_len,
          )

+         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
          qo_indptr = qo_indptr[: bs + 1]
-         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

          # extend part
          if use_ragged:
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -24,6 +24,8 @@ It supports page size = 1.
  import triton
  import triton.language as tl

+ from sglang.srt.utils import is_hip
+

  @triton.jit
  def tanh(x):
@@ -296,12 +298,18 @@ def _fwd_grouped_kernel_stage1(
      Lk: tl.constexpr,
  ):
      cur_batch = tl.program_id(0)
-     cur_kv_head = tl.program_id(1)
+     cur_head_id = tl.program_id(1)
+     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
      start_n = tl.program_id(2)

      reduce_dtype = Att_Out.dtype.element_ty
-     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-     mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+
+     if BLOCK_H < kv_group_num:
+         VALID_BLOCK_H: tl.constexpr = BLOCK_H
+     else:
+         VALID_BLOCK_H: tl.constexpr = kv_group_num
+     cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+     mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
      mask_h = mask_h & (cur_head < q_head_num)

      offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -400,10 +408,15 @@ def _fwd_grouped_kernel_stage2(
      Lv: tl.constexpr,
  ):
      cur_batch = tl.program_id(0)
-     cur_kv_head = tl.program_id(1)
+     cur_head_id = tl.program_id(1)
+     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

-     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-     mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+     if BLOCK_H < kv_group_num:
+         VALID_BLOCK_H: tl.constexpr = BLOCK_H
+     else:
+         VALID_BLOCK_H: tl.constexpr = kv_group_num
+     cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+     mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
      mask_h = mask_h & (cur_head < q_head_num)

      cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -485,7 +498,7 @@ def _decode_grouped_att_m_fwd(
      batch, head_num = B_req_idx.shape[0], q.shape[1]
      kv_group_num = q.shape[1] // k_buffer.shape[1]

-     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
      grid = (
          batch,
          triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
@@ -534,7 +547,7 @@ def _decode_grouped_softmax_reducev_fwd(
      BLOCK = 128
      batch, head_num = b_seq_len.shape[0], logits.shape[0]
      kv_group_num = logits.shape[0] // v_buffer.shape[1]
-     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
      grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

      num_warps = 8
@@ -542,6 +555,12 @@
      Lv = v_buffer.shape[-1]
      BLOCK_DMODEL = triton.next_power_of_2(Lv)

+     extra_kargs = {}
+     if is_hip():
+         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
      _fwd_grouped_kernel_stage2[grid](
          logits,
          v_buffer,
@@ -564,6 +583,81 @@
          Lv=Lv,
          num_warps=num_warps,
          num_stages=1,
+         **extra_kargs,
+     )
+
+
+ def decode_attention_fwd_normal(
+     q,
+     k_buffer,
+     v_buffer,
+     o,
+     req_to_token,
+     b_req_idx,
+     b_start_loc,
+     b_seq_len,
+     attn_logits,
+     max_len_in_batch,
+     sm_scale,
+     logit_cap=0.0,
+ ):
+     _decode_att_m_fwd(
+         q,
+         k_buffer,
+         attn_logits,
+         req_to_token,
+         b_req_idx,
+         b_start_loc,
+         b_seq_len,
+         max_len_in_batch,
+         sm_scale,
+         logit_cap,
+     )
+     _decode_softmax_reducev_fwd(
+         attn_logits,
+         v_buffer,
+         o,
+         req_to_token,
+         b_req_idx,
+         b_start_loc,
+         b_seq_len,
+     )
+
+
+ def decode_attention_fwd_grouped(
+     q,
+     k_buffer,
+     v_buffer,
+     o,
+     req_to_token,
+     b_req_idx,
+     b_start_loc,
+     b_seq_len,
+     attn_logits,
+     max_len_in_batch,
+     sm_scale,
+     logit_cap=0.0,
+ ):
+     _decode_grouped_att_m_fwd(
+         q,
+         k_buffer,
+         attn_logits,
+         req_to_token,
+         b_req_idx,
+         b_start_loc,
+         b_seq_len,
+         max_len_in_batch,
+         sm_scale,
+         logit_cap,
+     )
+     _decode_grouped_softmax_reducev_fwd(
+         attn_logits,
+         v_buffer,
+         o,
+         req_to_token,
+         b_req_idx,
+         b_start_loc,
+         b_seq_len,
      )


@@ -585,47 +679,33 @@ def decode_attention_fwd(

      if kv_group_num == 1:
          # MHA
-         _decode_att_m_fwd(
+         decode_attention_fwd_normal(
              q,
              k_buffer,
-             attn_logits,
+             v_buffer,
+             o,
              req_to_token,
              b_req_idx,
              b_start_loc,
              b_seq_len,
+             attn_logits,
              max_len_in_batch,
              sm_scale,
              logit_cap,
          )
-         _decode_softmax_reducev_fwd(
-             attn_logits,
-             v_buffer,
-             o,
-             req_to_token,
-             b_req_idx,
-             b_start_loc,
-             b_seq_len,
-         )
      else:
          # GQA/MQA/MLA
-         _decode_grouped_att_m_fwd(
+         decode_attention_fwd_grouped(
              q,
              k_buffer,
-             attn_logits,
+             v_buffer,
+             o,
              req_to_token,
              b_req_idx,
              b_start_loc,
              b_seq_len,
+             attn_logits,
              max_len_in_batch,
              sm_scale,
              logit_cap,
          )
-         _decode_grouped_softmax_reducev_fwd(
-             attn_logits,
-             v_buffer,
-             o,
-             req_to_token,
-             b_req_idx,
-             b_start_loc,
-             b_seq_len,
-         )
sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -168,7 +168,7 @@ def _fwd_kernel(
  def context_attention_fwd(
      q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
  ):
-     if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+     if is_cuda_available and CUDA_CAPABILITY[0] > 8:
          BLOCK = 128
      else:
          BLOCK = 64
sglang/srt/layers/fused_moe/fused_moe.py
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
  from vllm.logger import init_logger

  logger = init_logger(__name__)
+ padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


  @triton.jit
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
          expert_ids,
          num_tokens_post_padded,
          B.shape[1],
-         B.shape[2],
+         B.shape[2] - padding_size,
          sorted_token_ids.shape[0],
          topk_ids.numel(),
          A.stride(0),
@@ -464,7 +465,7 @@ def fused_experts(
      a2_scale: Optional[torch.Tensor] = None,
  ):
      # Check constraints.
-     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+     assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
      assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
      assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
      assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@
      get_config_func = functools.partial(
          try_get_optimal_moe_config,
          w1.shape,
-         w2.shape,
+         (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
          topk_ids.shape[1],
          "float8" if use_fp8 else None,
          override_config=override_config,
sglang/srt/layers/fused_moe/layer.py
@@ -1,9 +1,11 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+ import os
  from abc import abstractmethod
  from typing import List, Optional, Tuple

  import torch
+ import torch.nn.functional as F
  from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
  from vllm.model_executor.layers.quantization.fp8 import Fp8Config
  from vllm.model_executor.utils import set_weight_attrs

+ from sglang.srt.layers.fused_moe.fused_moe import padding_size
  from sglang.srt.utils import is_hip

  logger = init_logger(__name__)
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  )
              layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
              layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+             # If ROCm, apply weight padding (min. Mem channel contention) only if set
+             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                 layer.w13_weight = torch.nn.Parameter(
+                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
+                 layer.w2_weight = torch.nn.Parameter(
+                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
              return

          # If checkpoint is fp8, we need to handle that the
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  start += shard_size

              layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+             # If ROCm, apply weight padding (min. Mem channel contention) only if set
+             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                 layer.w13_weight = torch.nn.Parameter(
+                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
+                 layer.w2_weight = torch.nn.Parameter(
+                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
              return

      def apply(
sglang/srt/layers/logits_processor.py
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
      # The logits of the next tokens. shape: [#seq, vocab_size]
      next_token_logits: torch.Tensor
      # The logprobs of the next tokens. shape: [#seq, vocab_size]
-     next_token_logprobs: torch.Tensor
+     next_token_logprobs: torch.Tensor = None

      # The normlaized logprobs of prompts. shape: [#seq]
-     normalized_prompt_logprobs: torch.Tensor
+     normalized_prompt_logprobs: torch.Tensor = None
      # The logprobs of input tokens. shape: [#token, vocab_size]
-     input_token_logprobs: torch.Tensor
+     input_token_logprobs: torch.Tensor = None

      # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-     input_top_logprobs: List
+     input_top_logprobs: List = None
      # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-     output_top_logprobs: List
+     output_top_logprobs: List = None


  @dataclasses.dataclass
sglang/srt/layers/quantization/base_config.py
@@ -1,7 +1,8 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py

+ import inspect
  from abc import ABC, abstractmethod
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Type

  import torch
  from torch import nn
@@ -120,3 +121,17 @@ class QuantizationConfig(ABC):
          For now, this is only used by AWQ.
          """
          raise NotImplementedError
+
+
+ def method_has_implemented_embedding(
+         method_class: Type[QuantizeMethodBase]) -> bool:
+     """
+     Not all quant methods have embedding implemented, so we need to check that
+     it exists for our given method. We check this by making sure the function
+     has been changed from the base implementation.
+     """
+     base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
+                                             None)
+     class_embedding = inspect.getattr_static(method_class, "embedding", None)
+
+     return (class_embedding is not None
+             and class_embedding is not base_embedding)