sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py

@@ -64,7 +64,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True

     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0

     def forward_extend(
         self,
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         if not self.use_mla:
             if save_kv_cache:
@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):

             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +208,61 @@ class AscendAttnBackend(AttentionBackend):
                 )

             else:
-                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )

-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    mask=self.mask,
-                    block_table=self.forward_metadata.block_tables,
-                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    scale_value=layer.scaling,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    out=attn_output,
-                )
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
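
The new `else` branch above routes extend attention with `layer.qk_head_dim > 128` through the torch-native backend (`self.native_attn._run_sdpa_forward_extend`), which is why the first hunk appears to move the `TorchNativeAttnBackend` construction out of the MLA-only block in `__init__`. The helper's body is not part of this diff; as a rough sketch only, a GQA-aware causal SDPA call of the kind such a fallback builds on could look like the following (function name and shape handling are illustrative assumptions, not the sglang implementation):

import torch
import torch.nn.functional as F

def sdpa_extend_fallback(q, k, v, scaling, causal=True, enable_gqa=False):
    # q: [num_tokens, num_q_heads, head_dim]; k, v: [num_tokens, num_kv_heads, head_dim]
    # SDPA expects [batch, heads, seq, head_dim], so add a batch dim of 1.
    q_, k_, v_ = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
    out = F.scaled_dot_product_attention(
        q_, k_, v_,
        is_causal=causal,
        scale=scaling,
        enable_gqa=enable_gqa,  # requires PyTorch >= 2.5
    )
    # Back to [num_tokens, num_q_heads, head_dim].
    return out.squeeze(0).transpose(0, 1)
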
@@ -253,6 +288,136 @@ class AscendAttnBackend(AttentionBackend):

         return attn_output

+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
+            output = torch.empty(
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
+            )
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -260,106 +425,74 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = False,
+        save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, forward_batch.out_cache_loc, k, v
                 )
             num_tokens = q.shape[0]
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
-                )
-                attn_output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q.view(
+                        forward_batch.batch_size,
+                        -1,
+                        layer.tp_q_head_num,
+                        layer.qk_head_dim,
+                    ),
+                    k_cache.view(
+                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                    ),
+                    v_cache.view(
+                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                    ),
                     num_heads=layer.tp_q_head_num,
                     num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
+                    input_layout="BSND",
+                    atten_mask=None,
+                    block_size=self.page_size,
+                    block_table=self.forward_metadata.block_tables,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                     scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[attn_output, softmax_lse],
                 )
             else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
+                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                num_tokens = query.shape[0]
+                attn_output = torch.empty(
+                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
                 )
-                if self.use_fia:
-                    attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-                        q.view(
-                            forward_batch.batch_size,
-                            -1,
-                            layer.tp_q_head_num,
-                            layer.qk_head_dim,
-                        ),
-                        k_cache.view(
-                            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
-                        ),
-                        v_cache.view(
-                            -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
-                        ),
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSND",
-                        atten_mask=None,
-                        block_size=self.page_size,
-                        block_table=self.forward_metadata.block_tables,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
-                        scale=layer.scaling,
-                    )
-                else:
-                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                    attn_output = torch.empty(
-                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                        dtype=query.dtype,
-                        device=query.device,
-                    )

-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=k_cache,
-                        value_cache=v_cache,
-                        num_heads=layer.tp_q_head_num,
-                        num_kv_heads=layer.tp_k_head_num,
-                        scale_value=layer.scaling,
-                        block_table=self.forward_metadata.block_tables,
-                        context_lens=self.forward_metadata.seq_lens_cpu_int,
-                        out=attn_output,
-                    )
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    out=attn_output,
+                )
             return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if save_kv_cache:
@@ -370,9 +503,7 @@ class AscendAttnBackend(AttentionBackend):
             kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

-            if (self.graph_mode or self.use_fia) and (
-                layer.tp_q_head_num // layer.tp_k_head_num
-            ) >= 8:
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
                 """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
                 kv_c = kv_c.view(
                     -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
sglang/srt/layers/attention/hybrid_attn_backend.py

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""

     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.decode_backend.init_forward_metadata_capture_cuda_graph(
-            bs,
-            num_tokens,
-            req_pool_indices,
-            seq_lens,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.decode_backend.init_forward_metadata_replay_cuda_graph(
-            bs,
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-            seq_lens_cpu,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )

     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
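
The hunks above switch HybridAttnBackend from always delegating CUDA-graph metadata setup to the decode backend over to routing each call on forward_mode.is_decode_or_idle(), and they thread a ModelRunner through __init__ so the prefill backend's graph state can also be captured when speculative decoding needs target_verify. A condensed sketch of that dispatch pattern (simplified class and method names, not the real sglang interfaces):

class HybridDispatch:
    """Illustrative only: route calls to a decode or prefill backend by forward mode."""

    def __init__(self, model_runner, prefill_backend, decode_backend):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def _select(self, forward_mode):
        # Decode/idle batches use the decode backend; extend and target_verify
        # batches go to the prefill backend.
        return (
            self.decode_backend
            if forward_mode.is_decode_or_idle()
            else self.prefill_backend
        )

    def init_forward_metadata(self, forward_batch):
        self._select(forward_batch.forward_mode).init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
        if self.model_runner.server_args.speculative_algorithm is not None:
            # Speculative decoding also captures target_verify graphs on the prefill backend.
            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
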
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:

     workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None


 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )

         # Custom fast-path for decode/idle.
-        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        # Capture with full width so future longer sequences are safe during replay
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]

         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -217,13 +219,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             None,
             block_kv_indices,
             self.req_to_token.stride(0),
-            max_seqlen_pad,
+            max_blocks_per_seq,
             NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )

+        # Record the true maximum sequence length for this capture batch so that
+        # the kernel launch path (which requires an int not a tensor) can reuse
+        # it safely during both capture and replay.
+        max_seq_len_val = int(seq_lens.max().item())
+
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace, block_kv_indices
+            self.decode_cuda_graph_workspace,
+            block_kv_indices,
+            max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

+        # Update stored max_seq_len so subsequent kernel calls use the correct value
+        # Prefer CPU tensor to avoid GPU synchronization when available.
+        if seq_lens_cpu is not None:
+            metadata.max_seq_len = int(seq_lens_cpu.max().item())
+        else:
+            metadata.max_seq_len = int(seq_lens.max().item())
+
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             forward_batch.seq_lens.device,
         )

+        max_seq_len_val = int(max_seq)
         self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices
+            self.workspace_buffer, block_kv_indices, max_seq_len_val
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata

@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            max_seq_len=metadata.max_seq_len,
             bmm1_scale=bmm1_scale,
         )

-        # Extract value projection part and reshape
-        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
-        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-
+        # Reshape output directly without slicing
+        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

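The max_seq_len change above stops passing the padded block-table capacity (block_kv_indices.shape[1] * page_size) to the kernel and instead records the batch's true maximum sequence length, refreshing it from seq_lens_cpu on CUDA-graph replay when available to avoid a GPU sync. A small numeric sketch of the difference, using made-up values:

# Hypothetical numbers, for illustration only.
page_size = 64
num_block_columns = 128            # padded width of block_kv_indices
seq_lens_cpu = [900, 412, 37]      # actual sequence lengths in the batch

padded_capacity = num_block_columns * page_size  # 8192: what the old code passed as max_seq_len
true_max_seq_len = max(seq_lens_cpu)             # 900: what the new code passes

print(padded_capacity, true_max_seq_len)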
 
sglang/srt/layers/layernorm.py

@@ -53,7 +53,7 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

-if is_npu():
+if _is_npu:
     import torch_npu


@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x

-class Gemma3RMSNorm(nn.Module):
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)

+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
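
The comment in forward_native above marks the numeric detail the Gemma-style norms preserve: the (1 + weight) scale is applied while still in float32 and only then cast back, whereas a Llama-style RMSNorm casts to the working dtype first and multiplies by weight afterwards. A self-contained comparison of the two orderings in plain PyTorch (standalone functions for illustration, not the sglang classes):

import torch

def rmsnorm_llama_style(x, weight, eps=1e-6):
    # Normalize in float32, cast back to the input dtype, then scale by weight.
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return normed.to(x.dtype) * weight

def rmsnorm_gemma_style(x, weight, eps=1e-6):
    # Normalize and scale by (1 + weight) in float32, casting back only at the end.
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)

x = torch.randn(4, 8, dtype=torch.float16)
w = torch.zeros(8, dtype=torch.float16)  # Gemma stores weight as an offset from 1
print(torch.allclose(rmsnorm_llama_style(x, 1.0 + w).float(),
                     rmsnorm_gemma_style(x, w).float(), atol=1e-3))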