sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
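Note on the renames above: the Triton attention kernels moved into the sglang/srt/layers/triton_attention/ subpackage and llama2.py became llama.py. A minimal sketch of how downstream code that imported one of the moved kernel modules might adapt, assuming the kernel entry point keeps its 0.2.15 name decode_attention_fwd (the old import path appears in the radix_attention.py diff below):

try:
    # 0.3.1 layout: kernels live under the triton_attention subpackage
    from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
except ImportError:
    # 0.2.15 layout
    from sglang.srt.layers.decode_attention import decode_attention_fwd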
sglang/srt/layers/flashinfer_utils.py (new file)
@@ -0,0 +1,235 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def create_flashinfer_kv_indices_triton(
+     req_to_token_ptr,  # [max_batch, max_context_len]
+     req_pool_indices_ptr,
+     page_kernel_lens_ptr,
+     kv_indptr,
+     kv_start_idx,
+     kv_indices_ptr,
+     max_context_len: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 512
+     pid = tl.program_id(axis=0)
+     req_pool_index = tl.load(req_pool_indices_ptr + pid)
+     kv_indices_offset = tl.load(kv_indptr + pid)
+
+     kv_start = 0
+     kv_end = 0
+     if kv_start_idx:
+         kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+         kv_end = kv_start
+     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+     req_to_token_ptr += req_pool_index * max_context_len
+     kv_indices_ptr += kv_indices_offset
+
+     ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+     st_offset = tl.arange(0, BLOCK_SIZE)
+     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = ld_offset < kv_end
+         data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+         tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+         ld_offset += BLOCK_SIZE
+         st_offset += BLOCK_SIZE
+
+
+ class FlashinferUpdater:
+     def __init__(
+         self,
+         forward_mode,
+         model_runner,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         decode_wrapper=None,
+         use_ragged=False,
+     ):
+         self.forward_mode = forward_mode
+         self.model_runner = model_runner
+         self.req_pool_indices = req_pool_indices
+         self.seq_lens = seq_lens
+         self.prefix_lens = prefix_lens
+         self.use_ragged = use_ragged
+
+         self.num_qo_heads = (
+             model_runner.model_config.num_attention_heads // model_runner.tp_size
+         )
+         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+             model_runner.tp_size
+         )
+         self.head_dim = model_runner.model_config.head_dim
+         self.batch_size = len(req_pool_indices)
+
+         self.decode_wrapper = (
+             decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+         )
+         self.prefill_wrapper_ragged = (
+             self.model_runner.attn_backend.prefill_wrapper_ragged
+         )
+         self.prefill_wrapper_paged = (
+             self.model_runner.attn_backend.prefill_wrapper_paged
+         )
+
+         self.kv_last_page_len = torch.ones(
+             (self.batch_size,), dtype=torch.int32, device="cuda"
+         )
+
+     def _init_indices_no_sliding_window(self):
+         if self.use_ragged:
+             paged_kernel_lens = self.prefix_lens
+         else:
+             paged_kernel_lens = self.seq_lens
+
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         self.kv_indices = torch.empty(
+             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+         )
+
+         create_flashinfer_kv_indices_triton[(self.batch_size,)](
+             self.model_runner.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             paged_kernel_lens,
+             self.kv_indptr,
+             None,
+             self.kv_indices,
+             self.model_runner.req_to_token_pool.req_to_token.size(1),
+         )
+
+     def _init_indices_sliding_window(self, wrapper_id):
+         if wrapper_id == 0:
+             # window attention use paged only
+             if self.forward_mode.is_decode():
+                 paged_kernel_lens = torch.minimum(
+                     self.seq_lens,
+                     torch.tensor(self.model_runner.sliding_window_size + 1),
+                 )
+             else:
+                 paged_kernel_lens = torch.minimum(
+                     self.seq_lens,
+                     torch.tensor(self.model_runner.sliding_window_size)
+                     + self.seq_lens
+                     - self.prefix_lens,
+                 )
+         else:
+             # full attention
+             paged_kernel_lens = self.seq_lens
+
+         kv_start_idx = self.seq_lens - paged_kernel_lens
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         self.kv_indices = torch.empty(
+             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+         )
+         create_flashinfer_kv_indices_triton[(self.batch_size,)](
+             self.model_runner.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             paged_kernel_lens,
+             self.kv_indptr,
+             kv_start_idx,
+             self.kv_indices,
+             self.model_runner.req_to_token_pool.req_to_token.size(1),
+         )
+
+     def _update_decode_indices(self, decode_wrapper):
+         decode_wrapper.end_forward()
+         decode_wrapper.begin_forward(
+             self.kv_indptr,
+             self.kv_indices,
+             self.kv_last_page_len,
+             self.num_qo_heads,
+             self.num_kv_heads,
+             self.head_dim,
+             1,
+             data_type=self.model_runner.kv_cache_dtype,
+             q_data_type=self.model_runner.dtype,
+         )
+
+     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+         # extend part
+         qo_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
+
+         if self.use_ragged:
+             ragged_wrapper.end_forward()
+             ragged_wrapper.begin_forward(
+                 qo_indptr,
+                 qo_indptr,
+                 self.num_qo_heads,
+                 self.num_kv_heads,
+                 self.head_dim,
+             )
+
+         # cached part
+         paged_wrapper.end_forward()
+         paged_wrapper.begin_forward(
+             qo_indptr,
+             self.kv_indptr,
+             self.kv_indices,
+             self.kv_last_page_len,
+             self.num_qo_heads,
+             self.num_kv_heads,
+             self.head_dim,
+             1,
+         )
+
+     def update_indices_no_sliding_window(self):
+         self._init_indices_no_sliding_window()
+
+         if self.forward_mode.is_decode():
+             self._update_decode_indices(self.decode_wrapper)
+         else:
+             self._update_extend_indices(
+                 self.prefill_wrapper_ragged,
+                 self.prefill_wrapper_paged,
+             )
+
+     def update_indices_sliding_window(self):
+         assert self.use_ragged is False
+
+         for wrapper_id in range(2):
+             self._init_indices_sliding_window(wrapper_id)
+             if self.forward_mode.is_decode():
+                 self._update_decode_indices(self.decode_wrapper[wrapper_id])
+             else:
+                 self._update_extend_indices(
+                     None,
+                     self.prefill_wrapper_paged[wrapper_id],
+                 )
+
+
+ def update_flashinfer_indices(
+     forward_mode,
+     model_runner,
+     req_pool_indices,
+     seq_lens,
+     prefix_lens,
+     decode_wrapper=None,
+     use_ragged=False,
+ ):
+     updater = FlashinferUpdater(
+         forward_mode,
+         model_runner,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         decode_wrapper,
+         use_ragged,
+     )
+
+     if model_runner.sliding_window_size is None:
+         updater.update_indices_no_sliding_window()
+     else:
+         updater.update_indices_sliding_window()
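For reference, the kernel above copies, for each request i, the slice req_to_token[req_pool_indices[i], kv_start : kv_start + paged_kernel_lens[i]] into the flat kv_indices array starting at offset kv_indptr[i]. A PyTorch-only sketch of the same gather, with a made-up helper name and CPU tensors, purely for illustration:

import torch

def build_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens, kv_start_idx=None):
    # Mirrors create_flashinfer_kv_indices_triton without Triton: one loop per request.
    batch_size = len(req_pool_indices)
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(batch_size):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i])
        row = req_to_token[int(req_pool_indices[i])]
        kv_indices[int(kv_indptr[i]) : int(kv_indptr[i]) + length] = row[start : start + length]
    return kv_indptr, kv_indices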
sglang/srt/layers/logits_processor.py
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:
 
      # The normlaized logprobs of prompts. shape: [#seq]
      normalized_prompt_logprobs: torch.Tensor
-     # The logprobs of input tokens. shape: [#token, vocab_size]
+     # The logprobs of input tokens. shape: [#token, vocab_size]
      input_token_logprobs: torch.Tensor
 
      # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
  @dataclasses.dataclass
  class LogitsMetadata:
      forward_mode: ForwardMode
+     top_logprobs_nums: Optional[List[int]]
+
      return_logprob: bool = False
+     return_top_logprob: bool = False
 
      extend_seq_lens: Optional[torch.Tensor] = None
-     extend_start_loc: Optional[torch.Tensor] = None
-     top_logprobs_nums: Optional[List[int]] = None
+     extend_seq_lens_cpu: Optional[List[int]] = None
 
-     extend_seq_lens_cpu: List[int] = None
-     logprob_start_lens_cpu: List[int] = None
+     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
 
      @classmethod
      def from_input_metadata(cls, input_metadata: InputMetadata):
+         return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+         if input_metadata.forward_mode.is_extend():
+             extend_logprob_pruned_lens_cpu = [
+                 extend_len - start_len
+                 for extend_len, start_len in zip(
+                     input_metadata.extend_seq_lens,
+                     input_metadata.extend_logprob_start_lens_cpu,
+                 )
+             ]
+         else:
+             extend_logprob_pruned_lens_cpu = None
          return cls(
              forward_mode=input_metadata.forward_mode,
-             extend_seq_lens=input_metadata.extend_seq_lens,
-             extend_start_loc=input_metadata.extend_start_loc,
-             return_logprob=input_metadata.return_logprob,
              top_logprobs_nums=input_metadata.top_logprobs_nums,
+             return_logprob=input_metadata.return_logprob,
+             return_top_logprob=return_top_logprob,
+             extend_seq_lens=input_metadata.extend_seq_lens,
              extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-             logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
+             extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
          )
 
 
@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
      def _get_normalized_prompt_logprobs(
          self,
          input_token_logprobs: torch.Tensor,
-         cum_start_len0: torch.Tensor,
-         cum_start_len1: torch.Tensor,
          logits_metadata: LogitsMetadata,
      ):
          logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+         pruned_lens = torch.tensor(
+             logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+         )
 
-         start = logits_metadata.extend_start_loc.clone() - cum_start_len0
-         end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
-         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
-         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+         start = torch.zeros_like(pruned_lens)
+         start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+         end = torch.clamp(
+             start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+         )
          sum_logp = (
              logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
          )
-         normalized_prompt_logprobs = sum_logp / (
-             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
-         )
-
+         normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
          return normalized_prompt_logprobs
 
      @staticmethod
      def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-         if logits_metadata.forward_mode == ForwardMode.DECODE:
+         max_k = max(logits_metadata.top_logprobs_nums)
+         ret = all_logprobs.topk(max_k, dim=1)
+         values = ret.values.tolist()
+         indices = ret.indices.tolist()
+
+         if logits_metadata.forward_mode.is_decode():
              output_top_logprobs = []
-             max_k = max(logits_metadata.top_logprobs_nums)
-             ret = all_logprobs.topk(max_k, dim=1)
-             values = ret.values.tolist()
-             indices = ret.indices.tolist()
              for i, k in enumerate(logits_metadata.top_logprobs_nums):
                  output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
              return None, output_top_logprobs
          else:
-             # TODO: vectorize the code below
              input_top_logprobs, output_top_logprobs = [], []
-             pt = 0
-             extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
 
-             max_k = max(logits_metadata.top_logprobs_nums)
-             ret = all_logprobs.topk(max_k, dim=1)
-             values = ret.values.tolist()
-             indices = ret.indices.tolist()
-
-             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
-                 start_len = logits_metadata.logprob_start_lens_cpu[i]
-                 pruned_len = extend_seq_len - start_len
-
-                 if extend_seq_len == 0:
+             pt = 0
+             for k, pruned_len in zip(
+                 logits_metadata.top_logprobs_nums,
+                 logits_metadata.extend_logprob_pruned_lens_cpu,
+             ):
+                 if pruned_len <= 0:
                      input_top_logprobs.append([])
                      output_top_logprobs.append([])
                      continue
 
-                 k = logits_metadata.top_logprobs_nums[i]
                  input_top_logprobs.append(
                      [
                          list(zip(values[pt + j][:k], indices[pt + j][:k]))
@@ -163,14 +169,11 @@ class LogitsProcessor(nn.Module):
          assert isinstance(logits_metadata, LogitsMetadata)
 
          # Get the last hidden states and last logits for the next token prediction
-         if logits_metadata.forward_mode == ForwardMode.DECODE:
+         if logits_metadata.forward_mode.is_decode():
              last_index = None
              last_hidden = hidden_states
          else:
-             last_index = (
-                 torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-                 - 1
-             )
+             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
              last_hidden = hidden_states[last_index]
 
          last_logits = torch.matmul(last_hidden, weight.T)
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
                  output_top_logprobs=None,
              )
          else:
-             # When logprob is requested, compute the logits for all tokens.
-             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                 last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+             last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-                 # Get the logprob of top-k tokens
-                 return_top_logprob = any(
-                     x > 0 for x in logits_metadata.top_logprobs_nums
-                 )
-                 if return_top_logprob:
+             if logits_metadata.forward_mode.is_decode():
+                 if logits_metadata.return_top_logprob:
                      output_top_logprobs = self.get_top_logprobs(
                          last_logprobs, logits_metadata
                      )[1]
                  else:
                      output_top_logprobs = None
-
                  return LogitsProcessorOutput(
                      next_token_logits=last_logits,
                      next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
                      output_top_logprobs=output_top_logprobs,
                  )
              else:
+                 # Slice the requested tokens to compute logprob
                  pt, states, pruned_input_ids = 0, [], []
-                 for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
-                     start_len = logits_metadata.logprob_start_lens_cpu[i]
+                 for start_len, extend_len in zip(
+                     logits_metadata.extend_logprob_start_lens_cpu,
+                     logits_metadata.extend_seq_lens_cpu,
+                 ):
                      states.append(hidden_states[pt + start_len : pt + extend_len])
                      pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                      pt += extend_len
 
+                 # Compute the logits and logprobs for all required tokens
                  states = torch.cat(states, dim=0)
-                 pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
-
-                 cum_start_len1 = torch.tensor(
-                     logits_metadata.logprob_start_lens_cpu, device="cuda"
-                 ).cumsum(0)
-                 cum_start_len0 = torch.zeros_like(cum_start_len1)
-                 cum_start_len0[1:] = cum_start_len1[:-1]
-
                  all_logits = torch.matmul(states, weight.T)
                  if self.do_tensor_parallel_all_gather:
                      all_logits = tensor_model_parallel_all_gather(all_logits)
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
                  all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
                  # Get the logprob of top-k tokens
-                 return_top_logprob = any(
-                     x > 0 for x in logits_metadata.top_logprobs_nums
-                 )
-                 if return_top_logprob:
+                 if logits_metadata.return_top_logprob:
                      input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                          all_logprobs, logits_metadata
                      )
                  else:
                      input_top_logprobs = output_top_logprobs = None
 
-                 last_logprobs = all_logprobs[last_index - cum_start_len1]
-
-                 # Compute the logprobs and normalized logprobs for the prefill tokens.
-                 # Note that we pad a zero at the end of each sequence for easy computation.
+                 # Compute the normalized logprobs for the requested tokens.
+                 # Note that we pad a zero at the end for easy batching.
                  input_token_logprobs = all_logprobs[
                      torch.arange(all_logprobs.shape[0], device="cuda"),
-                     torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
+                     torch.cat(
+                         [
+                             torch.cat(pruned_input_ids)[1:],
+                             torch.tensor([0], device="cuda"),
+                         ]
+                     ),
                  ]
-
                  normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                      input_token_logprobs,
-                     cum_start_len0,
-                     cum_start_len1,
                      logits_metadata,
                  )
 
-                 # Remove the last token logprob for the prefill tokens.
-                 input_token_logprobs = input_token_logprobs[:-1]
-
                  return LogitsProcessorOutput(
                      next_token_logits=last_logits,
                      next_token_logprobs=last_logprobs,
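The reworked _get_normalized_prompt_logprobs above drops the cum_start_len bookkeeping and instead derives per-request offsets from the pruned lengths with a cumulative sum. A self-contained example of that computation with made-up numbers (CPU tensors, illustrative only):

import torch

# Two prompts whose pruned token logprobs are laid out back to back (lengths 3 and 2).
input_token_logprobs = torch.tensor([-1.0, -2.0, -0.5, -3.0, -1.5])
pruned_lens = torch.tensor([3, 2])

logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)  # tensor([0, 3])
end = torch.clamp(start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1)  # tensor([1, 3])
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)
# normalized == tensor([-1.5000, -3.0000]): (-1.0 - 2.0) / 2 and -3.0 / 1,
# i.e. the last position of each prompt is excluded from the average.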
sglang/srt/layers/radix_attention.py
@@ -15,20 +15,16 @@ limitations under the License.
 
  """Radix attention."""
 
- from typing import Optional
-
- import torch
- from flashinfer.cascade import merge_state
  from torch import nn
 
- from sglang.global_config import global_config
- from sglang.srt.layers.decode_attention import decode_attention_fwd
- from sglang.srt.layers.extend_attention import extend_attention_fwd
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
- from sglang.srt.model_executor.model_runner import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
  class RadixAttention(nn.Module):
+     """
+     The attention layer implementation.
+     """
+
      def __init__(
          self,
          num_heads: int,
@@ -36,8 +32,8 @@ class RadixAttention(nn.Module):
          scaling: float,
          num_kv_heads: int,
          layer_id: int,
-         sliding_window_size: Optional[int] = None,
-         logit_cap: int = -1,
+         sliding_window_size: int = -1,
+         logit_cap: float = 0.0,
          v_head_dim: int = -1,
      ):
          super().__init__()
@@ -49,160 +45,14 @@ class RadixAttention(nn.Module):
          self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
          self.scaling = scaling
          self.layer_id = layer_id
-         self.sliding_window_size = sliding_window_size if sliding_window_size else -1
-
-         if (
-             not global_server_args_dict.get("disable_flashinfer", False)
-             and self.qk_head_dim == self.v_head_dim
-         ):
-             self.extend_forward = self.extend_forward_flashinfer
-             self.decode_forward = self.decode_forward_flashinfer
-         else:
-             self.extend_forward = self.extend_forward_triton
-             self.decode_forward = self.decode_forward_triton
-
-         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
-     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-         if self.qk_head_dim != self.v_head_dim:
-             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-
-         self.store_kv_cache(k, v, input_metadata)
-         extend_attention_fwd(
-             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-             k.contiguous(),
-             v.contiguous(),
-             o.view(-1, self.tp_q_head_num, self.v_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             input_metadata.triton_start_loc,
-             input_metadata.seq_lens,
-             input_metadata.triton_prefix_lens,
-             input_metadata.extend_start_loc,
-             input_metadata.extend_seq_lens,
-             input_metadata.triton_max_seq_len,
-             input_metadata.triton_max_extend_len,
-             sm_scale=self.scaling,
-             logit_cap=self.logit_cap,
-         )
-
-         return o
-
-     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-         if self.qk_head_dim != self.v_head_dim:
-             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-         self.store_kv_cache(k, v, input_metadata)
-
-         decode_attention_fwd(
-             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-             o.view(-1, self.tp_q_head_num, self.v_head_dim),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             input_metadata.triton_start_loc,
-             input_metadata.seq_lens,
-             input_metadata.triton_max_seq_len,
-             input_metadata.total_num_tokens,
-             sm_scale=self.scaling,
-             logit_cap=self.logit_cap,
-         )
-
-         return o
-
-     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-         if self.sliding_window_size != -1:
-             prefill_wrapper_paged = prefill_wrapper_paged[0]
-         else:
-             if isinstance(prefill_wrapper_paged, list):
-                 prefill_wrapper_paged = prefill_wrapper_paged[1]
-
-         if not input_metadata.flashinfer_use_ragged:
-             if k is not None:
-                 assert v is not None
-                 self.store_kv_cache(k, v, input_metadata)
-
-             o = prefill_wrapper_paged.forward(
-                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                 causal=True,
-                 sm_scale=self.scaling,
-                 window_left=self.sliding_window_size,
-                 logits_soft_cap=self.logit_cap,
-             )
-         else:
-             o1, s1 = (
-                 input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                     k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                     v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                     causal=True,
-                     sm_scale=self.scaling,
-                     logits_soft_cap=self.logit_cap,
-                 )
-             )
-
-             if input_metadata.extend_no_prefix:
-                 o = o1
-             else:
-                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                     input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                     causal=False,
-                     sm_scale=self.scaling,
-                     logits_soft_cap=self.logit_cap,
-                 )
-
-                 o, _ = merge_state(o1, s1, o2, s2)
-
-             self.store_kv_cache(k, v, input_metadata)
-
-         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-             torch.cuda.synchronize()
-
-         return o.view(-1, self.tp_q_head_num * self.head_dim)
-
-     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-         decode_wrapper = input_metadata.flashinfer_decode_wrapper
-         if self.sliding_window_size != -1:
-             decode_wrapper = decode_wrapper[0]
-         else:
-             if isinstance(decode_wrapper, list):
-                 decode_wrapper = decode_wrapper[1]
-
-         if k is not None:
-             assert v is not None
-             self.store_kv_cache(k, v, input_metadata)
-
-         o = decode_wrapper.forward(
-             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-             sm_scale=self.scaling,
-             logits_soft_cap=self.logit_cap,
-         )
-
-         return o.view(-1, self.tp_q_head_num * self.head_dim)
+         self.logit_cap = logit_cap
+         self.sliding_window_size = sliding_window_size or -1
 
      def forward(self, q, k, v, input_metadata: InputMetadata):
          if k is not None:
+             # For cross-layer sharing, kv can be None
              assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
-         if input_metadata.forward_mode == ForwardMode.EXTEND:
-             return self.extend_forward(q, k, v, input_metadata)
-         elif input_metadata.forward_mode == ForwardMode.DECODE:
-             return self.decode_forward(q, k, v, input_metadata)
-
-     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-         input_metadata.token_to_kv_pool.set_kv_buffer(
-             self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
-         )
+         return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
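With this change RadixAttention no longer chooses between the Triton and flashinfer paths itself; it stores KV and runs attention through the attn_backend object carried on InputMetadata, implemented in the new sglang/srt/layers/attention_backend.py (not shown in this diff). A hypothetical sketch of the interface such a backend would have to expose; everything beyond the forward(q, k, v, layer, input_metadata) call seen above is an assumption:

from abc import ABC, abstractmethod

class AttentionBackendSketch(ABC):
    """Illustrative stand-in for the backend object RadixAttention delegates to."""

    @abstractmethod
    def forward(self, q, k, v, layer, input_metadata):
        """Write the layer's new k/v into the KV pool, then run extend or decode
        attention for the current batch and return the attention output."""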