sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,6 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -46,6 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
@@ -113,14 +117,25 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+            if k is not None:
+                assert v is not None
+                self.store_kv_cache(k, v, input_metadata)
 
-            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
@@ -138,14 +153,12 @@ class RadixAttention(nn.Module):
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                        causal=False,
-                        sm_scale=self.scaling,
-                        logits_soft_cap=self.logit_cap,
-                    )
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
                 )
 
                 o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +171,18 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
+
+        if k is not None:
+            assert v is not None
+            self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.flashinfer_decode_wrapper.forward(
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
@@ -170,8 +192,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
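
Note on the radix_attention.py hunks above: when a layer has a sliding window, the new code expects the runner to hand each attention call a two-element list of FlashInfer wrappers and picks index 0 for window layers and index 1 for full-attention layers; models without any window keep passing a single wrapper. The following is a minimal, self-contained sketch of that selection rule only; `DummyWrapper` and `pick_wrapper` are illustrative names, not part of sglang or flashinfer.

```python
# Minimal sketch (not sglang's API) of the wrapper selection added to
# RadixAttention.extend_forward_flashinfer / decode_forward_flashinfer.
from typing import List, Union


class DummyWrapper:
    def __init__(self, name: str) -> None:
        self.name = name


def pick_wrapper(
    wrapper_or_pair: Union[DummyWrapper, List[DummyWrapper]],
    sliding_window_size: int,
) -> DummyWrapper:
    """Mirror the diff's logic: window layers use index 0, full layers index 1."""
    if sliding_window_size != -1:
        # Sliding-window layer: the runner builds a two-element list and
        # slot 0 holds the window-attention wrapper.
        return wrapper_or_pair[0]
    if isinstance(wrapper_or_pair, list):
        # Full-attention layer in a model that also has window layers.
        return wrapper_or_pair[1]
    # Model without sliding windows: a single wrapper is passed through.
    return wrapper_or_pair


if __name__ == "__main__":
    pair = [DummyWrapper("window"), DummyWrapper("full")]
    assert pick_wrapper(pair, sliding_window_size=4096).name == "window"
    assert pick_wrapper(pair, sliding_window_size=-1).name == "full"
    assert pick_wrapper(DummyWrapper("only"), sliding_window_size=-1).name == "only"
```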
sglang/srt/managers/schedule_batch.py
@@ -235,10 +235,12 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-        if self.tokenizer is None:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        else:
-            matched_eos = last_token_id == self.tokenizer.eos_token_id
+
+        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+
+        if self.tokenizer is not None:
+            matched_eos |= last_token_id == self.tokenizer.eos_token_id
+
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
@@ -383,7 +385,7 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+    def batch_sampling_params(self, vocab_size):
         device = "cuda"
         bs, reqs = self.batch_size(), self.reqs
         self.temperatures = torch.tensor(
@@ -419,15 +421,8 @@
 
         # Handle logit bias but only allocate when needed
         self.logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if self.logit_bias is None:
-                    self.logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+    def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -466,7 +461,7 @@
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-        self.batch_sampling_params(vocab_size, int_token_logit_bias)
+        self.batch_sampling_params(vocab_size)
 
     def check_decode_mem(self):
         bs = self.batch_size()
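
Note on the `Req` change above: a stop-token match and the tokenizer's EOS are now combined with an OR instead of being mutually exclusive branches. A minimal sketch of the resulting predicate; `matched_eos` below is an illustrative standalone function, not sglang's actual method.

```python
# Minimal sketch of the revised EOS check in Req.check_finished (schedule_batch.py).
from typing import Optional, Set


def matched_eos(
    last_token_id: int,
    stop_token_ids: Set[int],
    eos_token_id: Optional[int] = None,
) -> bool:
    # 0.2.12 checked either the stop-token set or the tokenizer EOS, depending on
    # whether a tokenizer was attached; 0.2.13 always checks the stop-token set
    # and then also ORs in the tokenizer EOS when a tokenizer is available.
    matched = last_token_id in stop_token_ids
    if eos_token_id is not None:
        matched |= last_token_id == eos_token_id
    return matched


if __name__ == "__main__":
    assert matched_eos(2, stop_token_ids={32000}, eos_token_id=2)
    assert matched_eos(32000, stop_token_ids={32000}, eos_token_id=2)
    assert not matched_eos(5, stop_token_ids={32000}, eos_token_id=2)
```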
sglang/srt/managers/tokenizer_manager.py
@@ -507,7 +507,7 @@ class TokenizerManager:
         if obj.is_single:
             self.abort_request(obj.rid)
         else:
-            for rid in obj.rids:
+            for rid in obj.rid:
                 self.abort_request(rid)
 
         background_tasks = BackgroundTasks()
sglang/srt/managers/tp_worker.py
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -132,9 +131,6 @@ class ModelTpServer:
             ),
             self.model_runner.req_to_token_pool.size - 1,
         )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -442,9 +438,7 @@
 
     def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(
-            self.model_config.vocab_size, self.int_token_logit_bias
-        )
+        batch.prepare_for_extend(self.model_config.vocab_size)
 
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
sglang/srt/model_executor/cuda_graph_runner.py
@@ -98,8 +98,8 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros(
+        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
         self.out_cache_loc = torch.zeros(
@@ -107,9 +107,6 @@
         )
 
         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +118,23 @@
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+        else:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]
 
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
@@ -171,15 +185,32 @@
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -201,7 +232,7 @@
             out_cache_loc=out_cache_loc,
             return_logprob=False,
             top_logprobs_nums=0,
-            positions=(seq_lens - 1).to(torch.int64),
+            positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
@@ -225,8 +256,8 @@
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
        if bs != raw_bs:
-            self.seq_lens.fill_(1)
-            self.position_ids_offsets.zero_()
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
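
Note on the CudaGraphRunner hunks above: the padded defaults flip (seq_lens padded with 0, position_ids_offsets padded with 1) because the captured graph now computes positions as `seq_lens - 1 + position_ids_offsets`, so padded slots still resolve to position 0. A minimal sketch with toy CPU tensors; the real buffers are CUDA int32 tensors held by the runner.

```python
# Minimal sketch of the padding/position change in cuda_graph_runner.py.
import torch

max_bs, raw_bs = 4, 2

# 0.2.13 defaults: padded seq_lens are 0 and padded offsets are 1 ...
seq_lens = torch.zeros(max_bs, dtype=torch.int32)
position_ids_offsets = torch.ones(max_bs, dtype=torch.int32)

# ... and the real requests overwrite their slots before replay.
seq_lens[:raw_bs] = torch.tensor([7, 3], dtype=torch.int32)
position_ids_offsets[:raw_bs] = torch.tensor([0, 0], dtype=torch.int32)

# New positions formula from the diff: offsets now participate.
positions = (seq_lens - 1 + position_ids_offsets).to(torch.int64)

# Real slots get seq_len - 1 (+ any offset); padded slots still land on 0,
# which is what the old defaults (seq_lens=1, offsets=0) yielded implicitly.
assert positions.tolist() == [6, 2, 0, 0]
```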
sglang/srt/model_executor/forward_batch_info.py
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
@@ -194,6 +194,7 @@ class InputMetadata:
         if (
             forward_mode != ForwardMode.DECODE
             and int(torch.sum(ret.seq_lens)) > 4096
+            and model_runner.sliding_window_size is None
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
@@ -216,7 +217,10 @@
         self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
-        self, model_runner, prefix_lens, flashinfer_use_ragged
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
     ):
         update_flashinfer_indices(
             self.forward_mode,
@@ -255,65 +259,135 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
 
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        # CUDA graph uses different flashinfer_decode_wrapper
-        if flashinfer_decode_wrapper is None:
-            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
+                1,
             )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
 
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+    else:
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
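
Note on the `update_flashinfer_indices` hunk above: for the window wrapper during decode, only the most recent `sliding_window_size + 1` tokens of each request are indexed, via `kv_start_idx = seq_lens - paged_kernel_lens`. A minimal sketch with toy tensors; the real code slices `req_to_token_pool.req_to_token` on the GPU and feeds the result to FlashInfer.

```python
# Minimal sketch of the sliding-window KV selection added to
# update_flashinfer_indices (forward_batch_info.py).
import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 10])  # two requests in the batch

# Decode path of wrapper 0 (the window wrapper): keep at most window+1 tokens.
paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
kv_start_idx = seq_lens - paged_kernel_lens

# Request 0 is shorter than the window, so everything is kept (start 0, len 3);
# request 1 keeps only its last 5 positions (start 5, len 5).
assert paged_kernel_lens.tolist() == [3, 5]
assert kv_start_idx.tolist() == [0, 5]

# kv_indptr is then the running sum of per-request lengths, as in the diff.
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
assert kv_indptr.tolist() == [0, 3, 8]
```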
sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -53,7 +54,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8,
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -158,7 +159,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
-        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
@@ -168,15 +169,6 @@
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
-        if (
-            self.server_args.efficient_weight_load
-            and "llama" in self.server_args.model_path.lower()
-            and self.server_args.quantization == "fp8"
-        ):
-            from sglang.srt.model_loader.model_loader import get_model
-        else:
-            from vllm.model_executor.model_loader import get_model
-
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -187,6 +179,11 @@
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )
@@ -296,6 +293,9 @@
 
     def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
@@ -309,20 +309,47 @@
         else:
             use_tensor_cores = False
 
-        self.flashinfer_workspace_buffers = torch.empty(
-            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-        )
-        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD"
-        )
-        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[1], "NHD"
-        )
-        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
-            "NHD",
-            use_tensor_cores=use_tensor_cores,
-        )
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
@@ -358,7 +385,9 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, ForwardMode.DECODE
+            self,
+            batch,
+            ForwardMode.DECODE,
        )
 
         return self.model.forward(
@@ -368,7 +397,9 @@
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -377,7 +408,9 @@
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
         return self.model.forward(
             batch.input_ids,
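
Note on the ModelRunner hunks above: the runner enables the dual-wrapper sliding-window path only when the loaded model exposes a `get_window_size()` method. A minimal sketch of that probe; both model classes below are hypothetical stubs for illustration, not sglang model implementations.

```python
# Minimal sketch of the `get_window_size` probe added to ModelRunner.
from typing import Optional


class TinySlidingWindowModel:
    def __init__(self, window: int) -> None:
        self._window = window

    def get_window_size(self) -> int:
        return self._window


class TinyFullAttentionModel:
    pass


def detect_sliding_window(model) -> Optional[int]:
    # Same pattern as the diff: None means "no sliding window", and downstream
    # code then builds a single FlashInfer wrapper instead of a pair.
    return model.get_window_size() if hasattr(model, "get_window_size") else None


assert detect_sliding_window(TinySlidingWindowModel(4096)) == 4096
assert detect_sliding_window(TinyFullAttentionModel()) is None
```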