sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,11 +1,56 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List, Union
+
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.communication_op import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+
+
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+

 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,78 +58,159 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()

-    def forward(self, input_ids, hidden_states, weight, input_metadata):
-        last_index = None
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )

-        # Compute the last index (the first decode token) of each requeast
-        # if we are in prefill or extend mode.
-        if input_metadata.forward_mode != ForwardMode.DECODE:
-            last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+        # TODO: vectorize the code below
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = logits_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = logits_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_len
+
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
+
+        # Get the last hidden states and last logits for the next token prediction
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
+            last_index = (
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
+            last_hidden = hidden_states[last_index]

-        if not input_metadata.return_logprob:
-            # When logprob is not requested, only compute the last logits.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-            hidden_states = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
+        # Return only last_logits if logprob is not requested
+        if not logits_metadata.return_logprob:
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
-                logits = tensor_model_parallel_all_gather(logits)
-            logits = logits[:, : self.config.vocab_size]
-            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
-
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
-                last_logprobs = all_logprobs
-                prefill_logprobs = normalized_logprobs = None
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            # Get the logprob of top-k tokens
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, logits_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None
+
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
             else:
-                # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_logprobs = all_logprobs[
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )

-                start = input_metadata.extend_start_loc.clone()
-                end = start + input_metadata.extend_seq_lens - 2
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
-                )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, logits_metadata
                 )

-        return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
+                )


-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -93,23 +219,26 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")

-    logprobs = all_logprobs[
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs", logprobs)
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
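Note on the forward rewrite above: the new soft-capping branch fires when the model config defines final_logit_softcapping (e.g. Gemma 2, whose implementation is added in this release as sglang/srt/models/gemma2.py). The three in-place operations amount to cap * tanh(logits / cap). A minimal, self-contained sketch; the cap value 30.0 is illustrative only:

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same effect as the in-place sequence in forward():
    #   logits /= cap; logits = torch.tanh(logits); logits *= cap
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
capped = softcap(logits, cap=30.0)
print(capped)  # large values are squashed into the open interval (-30, 30)
assert capped.abs().max() < 30.0                      # magnitudes are bounded by the cap
assert torch.allclose(capped[2], torch.tensor(0.0))   # small logits are barely changed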
sglang/srt/layers/radix_attention.py

@@ -1,46 +1,42 @@
+"""Radix attention."""
+
 import torch
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from flashinfer.cascade import merge_state
+from torch import nn
+
+from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
-from torch import nn
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id

-        from sglang.srt.managers.router.model_runner import global_server_args_dict
-
-        if global_server_args_dict.get("enable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton

-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
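The constructor above normalizes logit_cap so that None or a non-positive value (the default is -1) means "no capping"; the Triton kernel in token_attention.py below applies the cap only when it is positive. A tiny sketch of that normalization rule:

def normalize_logit_cap(logit_cap):
    # Mirrors the constructor: None or a non-positive value disables capping.
    return logit_cap if logit_cap is not None and logit_cap > 0 else 0

assert normalize_logit_cap(None) == 0
assert normalize_logit_cap(-1) == 0      # the constructor default
assert normalize_logit_cap(30.0) == 30.0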
@@ -54,13 +50,15 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o
@@ -76,31 +74,54 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o

-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
-        o = input_metadata.prefill_wrapper.forward(
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
         )

+        if input_metadata.extend_no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        self.store_kv_cache(k, v, input_metadata)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -109,25 +130,13 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)

-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
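extend_forward_flashinfer above computes attention in two passes: a ragged pass of the new tokens against themselves (causal) and a paged pass against the prefix KV cache (non-causal), then combines the two with flashinfer.cascade.merge_state. Below is a self-contained reference of the underlying log-sum-exp merge in plain torch; it illustrates only the math, and the helper names, shapes, and LSE convention here are illustrative, not flashinfer's actual API:

import torch


def attend(q, k, v):
    # Single-query attention over one KV partition, returning (output, log-sum-exp).
    logits = k @ q                          # [num_kv]
    lse = torch.logsumexp(logits, dim=0)    # scalar
    return torch.softmax(logits, dim=0) @ v, lse


def merge_states(o1, lse1, o2, lse2):
    # Combine two partial attention outputs using their log-sum-exp weights.
    m = torch.maximum(lse1, lse2)
    w1, w2 = torch.exp(lse1 - m), torch.exp(lse2 - m)
    return (w1 * o1 + w2 * o2) / (w1 + w2)


torch.manual_seed(0)
q = torch.randn(64)
k, v = torch.randn(10, 64), torch.randn(10, 64)

o_full, _ = attend(q, k, v)           # attention over all 10 tokens at once
o1, s1 = attend(q, k[:6], v[:6])      # e.g. the cached prefix tokens
o2, s2 = attend(q, k[6:], v[6:])      # e.g. the newly extended tokens
assert torch.allclose(o_full, merge_states(o1, s1, o2, s2), atol=1e-5)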
sglang/srt/layers/token_attention.py

@@ -4,7 +4,8 @@
 import torch
 import triton
 import triton.language as tl
-from sglang.srt.managers.router.model_runner import global_server_args_dict
+
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
@@ -15,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -34,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
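The hunks above add a small tanh helper and thread a logit_cap constant into _fwd_kernel_stage1; the cap itself is applied a few hunks below. A quick plain-torch check of the identity the helper relies on:

import torch

x = torch.linspace(-6.0, 6.0, steps=101)
# tanh(x) == 2 * sigmoid(2x) - 1, the rewrite used by the @triton.jit tanh helper
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-6)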
@@ -76,6 +84,10 @@ def _fwd_kernel_stage1(
     ).to(REDUCE_TRITON_TYPE)
     att_value = tl.sum(q[None, :] * k, 1)
     att_value *= sm_scale
+
+    if logit_cap > 0:
+        att_value = logit_cap * tanh(att_value / logit_cap)
+
     off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
     tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -95,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -126,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
         )

         qk = tl.load(
@@ -164,13 +175,14 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -222,6 +234,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -236,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -263,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return

@@ -281,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -301,8 +311,9 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
+    sm_scale,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -319,6 +330,8 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
@@ -328,5 +341,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )