sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0

sglang/srt/constrained/xgrammar_backend.py
@@ -1,34 +1,30 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Constrained decoding with xgrammar backend."""
 
 import logging
 from typing import List, Tuple
 
 import torch
-
-try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
-
-    import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -38,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
 logger = logging.getLogger(__name__)
 
 
-MAX_ROLLBACK_TOKENS = 10
+MAX_ROLLBACK_TOKENS = 200
 
 
 class XGrammarGrammar(BaseGrammarObject):
@@ -80,20 +76,25 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return allocate_token_bitmask(batch_size, vocab_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must then be on the same device as logits
+            # when applying the token bitmask, so we check and move if needed
+            vocab_mask = vocab_mask.to(logits.device)
+
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
 
@@ -105,26 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-        if import_error:
-            logger.warning(
-                f"Ignore import error for the grammar backend: {import_error}"
-            )
-            self.grammar_cache = None
-            return
-
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
 
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
-                    key_string
-                )
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -138,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        if self.grammar_cache:
-            self.grammar_cache.clear()
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
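
Taken together, the xgrammar_backend.py hunks above replace the removed CachedGrammarCompiler/bitset flow with xgrammar's GrammarCompiler plus token-bitmask helpers. The sketch below only recombines calls that are visible in this diff into a standalone flow; the gpt2 tokenizer, the batch size of 1, and the CPU logits tensor are illustrative assumptions, not part of the package.

import torch
from transformers import AutoTokenizer
from xgrammar import (
    GrammarCompiler,
    GrammarMatcher,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

# Illustrative inputs: any HF tokenizer and a batch of next-token logits.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab_size = tokenizer.vocab_size
logits = torch.randn(1, vocab_size)

# Compile a JSON schema once, then drive one matcher per request.
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
ctx = compiler.compile_json_schema(schema='{"type": "object"}')
matcher = GrammarMatcher(ctx, max_rollback_tokens=200)

# Allocate one bitmask row per request in the batch and fill row 0 from this matcher.
vocab_mask = allocate_token_bitmask(1, vocab_size)
matcher.fill_next_token_bitmask(vocab_mask, 0)

# The mask must sit on the same device as the logits before it is applied in place.
if vocab_mask.device.type != logits.device.type:
    vocab_mask = vocab_mask.to(logits.device)
apply_token_bitmask_inplace(logits, vocab_mask)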

sglang/srt/conversation.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Conversation chat templates."""
 
 # Adapted from

sglang/srt/hf_transformers_utils.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Huggingface Transformers."""
 
 import contextlib

sglang/srt/layers/activation.py
@@ -1,16 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for activation layers."""
 
 import logging
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
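
The activation.py hunks above register each CustomOp subclass under a string name via register_custom_op from the new sglang/srt/layers/custom_op_util.py (file 25 in the list above; its body is not shown in this excerpt). A hypothetical sketch of the registration pattern implied by that decorator usage; the real helper may differ:

from typing import Callable, Dict

# Hypothetical registry; the actual sglang.srt.layers.custom_op_util
# implementation is not included in the hunks shown here.
_CUSTOM_OPS: Dict[str, type] = {}


def register_custom_op(name: str) -> Callable[[type], type]:
    # Class decorator that records an op class under a stable string name.
    def decorator(cls: type) -> type:
        _CUSTOM_OPS[name] = cls
        return cls

    return decorator


@register_custom_op("sglang_silu_and_mul")
class SiluAndMul:
    # Stand-in for the real CustomOp subclass defined in activation.py.
    pass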

sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,8 +7,9 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import os
 from enum import Enum, auto
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
 
         # Parse constants
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-        ):
-            self.decode_use_tensor_cores = True
+        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
+            self.decode_use_tensor_cores = (
+                os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
+            )
         else:
-            self.decode_use_tensor_cores = False
+            if not _grouped_size_compiled_for_decode_kernels(
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+            ):
+                self.decode_use_tensor_cores = True
+            else:
+                self.decode_use_tensor_cores = False
+
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
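
The hunk above lets an environment variable override the grouped-size heuristic that previously chose the decode kernels on its own. A minimal usage sketch (the variable name and the case-insensitive "true" comparison come straight from the hunk; setting it in-process assumes it happens before the backend is constructed):

import os

# Force FlashInfer decode onto tensor-core kernels; any other value disables
# them, and leaving the variable unset falls back to the
# _grouped_size_compiled_for_decode_kernels heuristic shown above.
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "true"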

@@ -136,15 +143,17 @@ class FlashInferAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
 
             # Some heuristics to check whether to use ragged forward
-            use_ragged = False
             if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                 use_ragged = True
-
-            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            else:
+                use_ragged = False
+                extend_no_prefix = False
 
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 prefix_lens,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
@@ -314,7 +323,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
        self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -335,7 +343,12 @@ class FlashInferIndicesUpdaterDecode:
            self.update = self.update_single_wrapper
 
     def update(
-        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -345,8 +358,8 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers=None,
-        encoder_lens=None,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -363,8 +376,8 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers=None,
-        encoder_lens=None,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -394,11 +407,11 @@
 
     def update_cross_attention(
         self,
-        req_pool_indices,
-        seq_lens,
-        seq_lens_sum,
-        decode_wrappers=None,
-        encoder_lens=None,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -425,11 +438,11 @@
     def call_begin_forward(
         self,
         wrapper,
-        req_pool_indices,
-        paged_kernel_lens,
-        paged_kernel_lens_sum,
-        kv_indptr,
-        kv_start_idx,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        kv_indptr: torch.Tensor,
+        kv_start_idx: torch.Tensor,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -445,7 +458,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )
 
         wrapper.end_forward()
@@ -474,7 +487,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
        self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -496,23 +508,40 @@ class FlashInferIndicesUpdaterPrefill:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper
 
-    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+    def update(
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
+    ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
 
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
    ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
+            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
             paged_kernel_lens = seq_lens
+            paged_kernel_lens_sum = seq_lens_sum
 
         self.call_begin_forward(
             self.wrapper_ragged,
             self.wrappers_paged[0],
             req_pool_indices,
             paged_kernel_lens,
+            paged_kernel_lens_sum,
             seq_lens,
             prefix_lens,
             None,
@@ -522,7 +551,13 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -531,9 +566,12 @@
                     seq_lens,
                     torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
                 )
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
             else:
                 # full attention
                 paged_kernel_lens = seq_lens
+                paged_kernel_lens_sum = seq_lens_sum
+
             kv_start_idx = seq_lens - paged_kernel_lens
 
             self.call_begin_forward(
@@ -541,6 +579,7 @@
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -550,23 +589,32 @@
             )
 
     def update_cross_attention(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
    ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # normal attention
                 paged_kernel_lens = seq_lens
                 kv_start_idx = encoder_lens
+                paged_kernel_lens_sum = seq_lens_sum
             else:
                 # cross attention
                 paged_kernel_lens = encoder_lens
                 kv_start_idx = torch.zeros_like(encoder_lens)
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
 
             self.call_begin_forward(
                 self.wrapper_ragged,
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -579,19 +627,22 @@
         self,
         wrapper_ragged,
         wrapper_paged,
-        req_pool_indices,
-        paged_kernel_lens,
-        seq_lens,
-        prefix_lens,
-        kv_start_idx,
-        kv_indptr,
-        qo_indptr,
-        use_ragged,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        seq_lens: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        kv_start_idx: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        qo_indptr: torch.Tensor,
+        use_ragged: bool,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,
@@ -599,7 +650,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )
 
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +689,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-    max_context_len: tl.constexpr,
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -652,15 +704,15 @@
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
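
The rewritten create_flashinfer_kv_indices_triton kernel indexes req_to_token through an explicit row stride (req_to_token_ptr_stride, passed as req_to_token.shape[1] by the callers above) instead of pre-advancing the pointers. A plain-PyTorch sketch of what one program instance (one per request) now computes, assuming kv_start_idx may be None as in the decode caller:

import torch


def create_kv_indices_reference(
    req_to_token, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx, kv_indices
):
    # Equivalent of one Triton program per request: gather the request's token
    # slots from its row of req_to_token into the flat kv_indices buffer.
    for pid in range(len(req_pool_indices)):
        req_pool_index = int(req_pool_indices[pid])
        kv_start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        kv_len = int(paged_kernel_lens[pid])
        out = int(kv_indptr[pid])
        kv_indices[out : out + kv_len] = req_to_token[
            req_pool_index, kv_start : kv_start + kv_len
        ]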

sglang/srt/layers/attention/triton_backend.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -28,9 +27,13 @@ class TritonAttnBackend(AttentionBackend):
 
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+
+        if model_runner.server_args.enable_dp_attention:
+            self.num_head = model_runner.model_config.num_attention_heads
+        else:
+            self.num_head = (
+                model_runner.model_config.num_attention_heads // model_runner.tp_size
+            )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
@@ -50,7 +53,7 @@
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
@@ -61,8 +64,7 @@
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = forward_batch.extend_prefix_lens
-            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
 