sglang 0.3.6__py3-none-any.whl → 0.3.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +4 -7
  4. sglang/bench_one_batch_server.py +2 -2
  5. sglang/bench_serving.py +75 -26
  6. sglang/check_env.py +7 -1
  7. sglang/lang/backend/base_backend.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +2 -2
  9. sglang/lang/tracer.py +1 -1
  10. sglang/launch_server.py +0 -3
  11. sglang/srt/configs/model_config.py +15 -20
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +13 -15
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +38 -57
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +13 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +14 -7
  21. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  22. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  23. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  24. sglang/srt/layers/custom_op_util.py +13 -14
  25. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  26. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  27. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  28. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  29. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  30. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  31. sglang/srt/layers/layernorm.py +13 -15
  32. sglang/srt/layers/logits_processor.py +13 -15
  33. sglang/srt/layers/quantization/__init__.py +77 -17
  34. sglang/srt/layers/radix_attention.py +13 -15
  35. sglang/srt/layers/rotary_embedding.py +13 -13
  36. sglang/srt/layers/sampler.py +1 -1
  37. sglang/srt/lora/lora.py +13 -14
  38. sglang/srt/lora/lora_config.py +13 -14
  39. sglang/srt/lora/lora_manager.py +22 -24
  40. sglang/srt/managers/data_parallel_controller.py +25 -19
  41. sglang/srt/managers/detokenizer_manager.py +13 -18
  42. sglang/srt/managers/image_processor.py +6 -9
  43. sglang/srt/managers/io_struct.py +43 -28
  44. sglang/srt/managers/schedule_batch.py +92 -27
  45. sglang/srt/managers/schedule_policy.py +13 -15
  46. sglang/srt/managers/scheduler.py +94 -72
  47. sglang/srt/managers/session_controller.py +29 -19
  48. sglang/srt/managers/tokenizer_manager.py +29 -22
  49. sglang/srt/managers/tp_worker.py +13 -15
  50. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  51. sglang/srt/metrics/collector.py +13 -15
  52. sglang/srt/metrics/func_timer.py +13 -15
  53. sglang/srt/mm_utils.py +13 -14
  54. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  55. sglang/srt/model_executor/forward_batch_info.py +19 -17
  56. sglang/srt/model_executor/model_runner.py +42 -30
  57. sglang/srt/models/chatglm.py +15 -16
  58. sglang/srt/models/commandr.py +15 -16
  59. sglang/srt/models/dbrx.py +15 -16
  60. sglang/srt/models/deepseek.py +15 -15
  61. sglang/srt/models/deepseek_v2.py +15 -15
  62. sglang/srt/models/exaone.py +14 -15
  63. sglang/srt/models/gemma.py +14 -14
  64. sglang/srt/models/gemma2.py +24 -19
  65. sglang/srt/models/gemma2_reward.py +13 -14
  66. sglang/srt/models/gpt_bigcode.py +14 -14
  67. sglang/srt/models/grok.py +15 -15
  68. sglang/srt/models/internlm2.py +13 -15
  69. sglang/srt/models/internlm2_reward.py +13 -14
  70. sglang/srt/models/llama.py +21 -21
  71. sglang/srt/models/llama_classification.py +13 -14
  72. sglang/srt/models/llama_reward.py +13 -14
  73. sglang/srt/models/llava.py +20 -16
  74. sglang/srt/models/llavavid.py +13 -15
  75. sglang/srt/models/minicpm.py +13 -15
  76. sglang/srt/models/minicpm3.py +13 -15
  77. sglang/srt/models/mistral.py +13 -15
  78. sglang/srt/models/mixtral.py +15 -15
  79. sglang/srt/models/mixtral_quant.py +14 -14
  80. sglang/srt/models/olmo.py +21 -19
  81. sglang/srt/models/olmoe.py +23 -20
  82. sglang/srt/models/qwen.py +14 -14
  83. sglang/srt/models/qwen2.py +22 -19
  84. sglang/srt/models/qwen2_moe.py +17 -18
  85. sglang/srt/models/stablelm.py +18 -16
  86. sglang/srt/models/torch_native_llama.py +15 -17
  87. sglang/srt/models/xverse.py +13 -14
  88. sglang/srt/models/xverse_moe.py +15 -16
  89. sglang/srt/models/yivl.py +13 -15
  90. sglang/srt/openai_api/adapter.py +13 -15
  91. sglang/srt/openai_api/protocol.py +13 -15
  92. sglang/srt/sampling/sampling_batch_info.py +4 -1
  93. sglang/srt/sampling/sampling_params.py +13 -15
  94. sglang/srt/server.py +60 -34
  95. sglang/srt/server_args.py +22 -22
  96. sglang/srt/utils.py +208 -19
  97. sglang/test/few_shot_gsm8k.py +8 -4
  98. sglang/test/runners.py +13 -14
  99. sglang/test/test_utils.py +2 -2
  100. sglang/version.py +1 -1
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/LICENSE +1 -1
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/METADATA +25 -15
  103. sglang-0.3.6.post2.dist-info/RECORD +164 -0
  104. sglang/srt/layers/fused_moe/__init__.py +0 -1
  105. sglang-0.3.6.dist-info/RECORD +0 -161
  106. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  107. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/WHEEL +0 -0
  108. {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/xgrammar_backend.py

@@ -1,39 +1,30 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Constrained decoding with xgrammar backend."""
 
 import logging
 from typing import List, Tuple
 
 import torch
-
-try:
-    from xgrammar import (
-        CachedGrammarCompiler,
-        CompiledGrammar,
-        GrammarMatcher,
-        TokenizerInfo,
-    )
-
-    import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -43,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
 logger = logging.getLogger(__name__)
 
 
-MAX_ROLLBACK_TOKENS = 10
+MAX_ROLLBACK_TOKENS = 200
 
 
 class XGrammarGrammar(BaseGrammarObject):
@@ -88,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject):
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
-        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+        return allocate_token_bitmask(batch_size, vocab_size)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must then be on the same device as logits
+            # when applying the token bitmask, so we check and move if needed
+            vocab_mask = vocab_mask.to(logits.device)
+
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
 
@@ -114,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-        if import_error:
-            logger.warning(
-                f"Ignore import error for the grammar backend: {import_error}"
-            )
-            self.grammar_cache = None
-            return
-
-        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
 
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -146,13 +131,9 @@
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        if self.grammar_cache:
-            self.grammar_cache.clear()
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
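Taken together, the xgrammar changes drop the try/except import shim and the CachedGrammarCompiler cache in favor of the newer module-level API (GrammarCompiler, allocate_token_bitmask, apply_token_bitmask_inplace), raise MAX_ROLLBACK_TOKENS from 10 to 200, and move the bitmask to the logits device before applying it. A minimal end-to-end sketch of the new flow, using only the calls that appear in the diff; the model name and vocab size are illustrative placeholders:

    import torch
    from transformers import AutoTokenizer
    from xgrammar import (
        GrammarCompiler,
        GrammarMatcher,
        TokenizerInfo,
        allocate_token_bitmask,
        apply_token_bitmask_inplace,
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    vocab_size = 128256  # the model's vocab size, which may exceed len(tokenizer)

    # Compile once per schema; GrammarCompiler caches compiled grammars internally,
    # which is why reset() now calls clear_cache() instead of clearing a cache object.
    tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
    compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
    ctx = compiler.compile_json_schema('{"type": "object"}')
    matcher = GrammarMatcher(ctx, max_rollback_tokens=200)

    # Per decoding step: fill the bitmask row, move it to the logits device, apply.
    logits = torch.randn(1, vocab_size, device="cuda")
    vocab_mask = allocate_token_bitmask(1, vocab_size)  # allocated on CPU
    matcher.fill_next_token_bitmask(vocab_mask, 0)
    apply_token_bitmask_inplace(logits, vocab_mask.to(logits.device))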
sglang/srt/conversation.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Conversation chat templates."""
 
 # Adapted from
sglang/srt/hf_transformers_utils.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Huggingface Transformers."""
 
 import contextlib
sglang/srt/layers/activation.py

@@ -1,16 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for activation layers."""
 
 import logging
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import os
 from enum import Enum, auto
 from typing import TYPE_CHECKING, List
 
@@ -17,7 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
 
         # Parse constants
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-        ):
-            self.decode_use_tensor_cores = True
+        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
+            self.decode_use_tensor_cores = get_bool_env_var(
+                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
+            )
         else:
-            self.decode_use_tensor_cores = False
+            if not _grouped_size_compiled_for_decode_kernels(
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+            ):
+                self.decode_use_tensor_cores = True
+            else:
+                self.decode_use_tensor_cores = False
+
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
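The net effect: SGLANG_FLASHINFER_USE_TENSOR_CORE now overrides the GQA heuristic in both directions, and the heuristic only runs when the variable is unset. A condensed sketch of the decision, with a stand-in for the FlashInfer helper and an assumed truthy-parsing behavior for get_bool_env_var (both hypothetical here):

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Assumed parsing: "true"/"1" (case-insensitive) count as truthy.
        return os.getenv(name, default).lower() in ("true", "1")

    def _grouped_size_compiled_for_decode_kernels(num_qo_heads: int,
                                                  num_kv_heads: int) -> bool:
        # Stand-in for the FlashInfer helper of the same name; the real function
        # reports whether a decode kernel was compiled for this GQA group size.
        return (num_qo_heads // num_kv_heads) in (1, 2, 4, 8)

    def decode_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
            # An explicit override wins, e.g. SGLANG_FLASHINFER_USE_TENSOR_CORE=true.
            return get_bool_env_var("SGLANG_FLASHINFER_USE_TENSOR_CORE")
        return not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)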
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for decoding.
 It supports page size = 1.
@@ -26,6 +24,8 @@
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -52,12 +52,13 @@
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
@@ -67,22 +68,18 @@
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
+    q = tl.load(Q + off_q).to(reduce_dtype)
 
-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
-
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -92,7 +89,7 @@
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
+            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
@@ -102,7 +99,7 @@
             att_value = logit_cap * tanh(att_value / logit_cap)
 
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
+        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)
 
 
 @triton.jit
@@ -191,11 +188,12 @@
     logit_cap,
 ):
     BLOCK = 32
+    SPLIT_K = 8
     Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
-    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
+    grid = (batch, head_num, SPLIT_K)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -223,6 +221,7 @@
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
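This is the core of the decode-kernel change: instead of launching triton.cdiv(max_len_in_batch, BLOCK) programs along the third grid axis, the launch is a fixed split-K of 8 programs, each of which owns a contiguous slice of the KV sequence and walks it in BLOCK_N chunks, so the grid no longer scales with the longest sequence in the batch. A plain-Python sketch of the slice arithmetic the kernel now performs (names mirror the diff):

    def split_k_slices(seq_len: int, split_k: int = 8, block_n: int = 32):
        # Mirrors the kernel: each program id along grid axis 2 gets one slice.
        kv_len_per_split = -(-seq_len // split_k)  # same as tl.cdiv
        for split_k_id in range(split_k):
            split_k_start = kv_len_per_split * split_k_id
            split_k_end = min(split_k_start + kv_len_per_split, seq_len)
            # The program then iterates its slice in BLOCK_N steps, masking the tail.
            steps = list(range(split_k_start, split_k_end, block_n))
            yield split_k_id, split_k_start, split_k_end, steps

    # Example: seq_len=100 gives 8 slices of ceil(100/8)=13 tokens; the last
    # slice covers tokens 91..99, and slices past the end would simply be empty.
    for pid, start, end, steps in split_k_slices(100):
        print(pid, start, end, steps)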
@@ -294,13 +293,14 @@
     BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
 
@@ -317,30 +317,27 @@
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+    q = tl.load(
+        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
+    ).to(reduce_dtype)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
+        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
 
-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(
-            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
-        ).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -350,14 +347,11 @@
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
+            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                reduce_dtype
-            )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
@@ -365,7 +359,7 @@
             )
             kpe = tl.load(
                 K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[None, :] < cur_batch_end_index,
+                mask=offs_n[None, :] < split_k_end,
                 other=0.0,
             ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
@@ -381,7 +375,7 @@
         tl.store(
             Att_Out + offs_o,
             qk,
-            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
         )
 
 
@@ -499,16 +493,17 @@
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
+    SPLIT_K = 8
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        triton.cdiv(max_len_in_batch, BLOCK),
+        SPLIT_K,
     )
 
     num_warps = 4
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -534,6 +529,7 @@
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -563,7 +559,7 @@
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
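A smaller recurring change in these Triton files: is_hip() is evaluated once at import time and cached in a module-level is_hip_, rather than being called on every kernel launch. The pattern, sketched below; the _rocm_extra_kargs helper is hypothetical, while the tuning values are the ones from the diff:

    from sglang.srt.utils import is_hip

    # Evaluate the platform check once at import time instead of per launch.
    is_hip_ = is_hip()

    def _rocm_extra_kargs() -> dict:
        if is_hip_:
            # ROCm-specific Triton tuning knobs; see the AMD docs linked in the diff.
            return {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
        return {}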
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for prefill.
 It supports page size = 1 and prefill with KV cache (i.e. extend).
@@ -31,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -313,7 +313,7 @@
     num_stages = 1
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
sglang/srt/layers/attention/triton_ops/prefill_attention.py

@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for prefill.
 It supporst page size = 1.