sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/ascend_backend.py
@@ -59,6 +59,19 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.mask_len = max_seq_len
 
+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers for verify attention kernels that need to be filled after draft.
+
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        pass
+
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.forward_metadata = None
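
The `[None, None]` return above means this backend keeps no graph-resident verify inputs. Backends that do pre-allocate such tensors hand them back from this hook so the draft stage can fill tree masks and positions in place. A purely hypothetical sketch for contrast (`tree_mask_buf` and `positions_buf` are invented names, not part of this diff):

```python
# Hypothetical sketch only: a backend that owns captured verify buffers.
# tree_mask_buf / positions_buf are invented attribute names for illustration.
def get_verify_buffers_to_fill_after_draft(self):
    # Hand back the pre-allocated tree mask and position tensors so the
    # draft stage can write into them once the draft tokens are known.
    return [self.tree_mask_buf, self.positions_buf]
```
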
@@ -87,15 +100,22 @@ class AscendAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         )
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
+        self.mtp_mask = ~self.mtp_mask
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
-
+        seq_lens_max = forward_batch.seq_lens.max()
+        if forward_batch.forward_mode.is_target_verify():
+            seq_lens_max += self.speculative_num_draft_tokens
         self.forward_metadata.block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+                forward_batch.req_pool_indices, :seq_lens_max
             ][:, :: self.page_size]
             // self.page_size
         )
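
A quick illustration of the mask convention behind `mtp_mask` above: `torch.tril` marks the causally visible positions, and the bitwise complement flips it so `True` marks positions to be masked out, which is the convention the NPU fused attention call added below appears to expect. A minimal sketch with CPU tensors:

```python
import torch

# True = masked out. The complement of a lower-triangular "visible" mask
# leaves True exactly on the strictly-upper (future) positions.
visible = torch.tril(torch.ones(4, 4, dtype=torch.bool))
mtp_mask = ~visible
print(mtp_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```
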
@@ -104,16 +124,23 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+            if (
+                not forward_batch.forward_mode.is_draft_extend_v2()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+            ):
+                seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+                self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
-            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
+        if forward_batch.forward_mode.is_target_verify():
+            self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
 
         self.graph_mode = False
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.graph_metadata = {
             "block_tables": torch.empty(
-                (max_bs, self.max_context_len // self.page_size),
+                (max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),
                 dtype=torch.int32,
                 device=self.device,
             ),
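
The block-table sizing change above is a plain ceiling division: the old floor division under-allocated one page whenever `max_context_len` is not a multiple of `page_size`. A worked example with made-up sizes:

```python
# Made-up sizes showing the off-by-one-page effect of floor division.
max_context_len, page_size = 130, 64

old_width = max_context_len // page_size                    # 2 pages -> only 128 tokens
new_width = (max_context_len + page_size - 1) // page_size  # 3 pages -> covers all 130

assert old_width * page_size < max_context_len  # old table was too narrow
assert new_width * page_size >= max_context_len
```
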
@@ -156,6 +183,8 @@ class AscendAttnBackend(AttentionBackend):
     ):
         metadata = self.graph_metadata[bs]
         max_len = seq_lens_cpu[:bs].max().item()
+        if forward_mode.is_target_verify():
+            max_len += self.speculative_num_draft_tokens
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
 
         metadata.block_tables[:bs, :max_seq_pages].copy_(
@@ -257,6 +286,25 @@ class AscendAttnBackend(AttentionBackend):
                 k_rope,
                 topk_indices,
             )
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend_v2()
+        ):
+
+            if is_mla_preprocess_enabled():
+                save_kv_cache = False
+            return self.forward_mtp(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -393,6 +441,118 @@ class AscendAttnBackend(AttentionBackend):
         )
         return attn_output
 
+    def forward_mtp(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        k_rope_cache = k_rope.view(
+            -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+        )
+        c_kv_cache = c_kv.view(
+            -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+        )
+
+        q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
+        q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
+        if not self.graph_mode:
+            num_token_padding = q.shape[0]
+            q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]
+            q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list
+        else:
+            actual_seq_lengths_kv = (
+                self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+            )
+        if forward_batch.forward_mode.is_draft_extend():
+            actual_seq_lengths = (
+                np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()
+            )
+        else:
+            actual_seq_lengths = np.arange(
+                self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens + q_nope.shape[0],
+                self.speculative_num_draft_tokens,
+            )
+
+        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+        )
+        attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
+        softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+        torch_npu.npu_fused_infer_attention_score.out(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            workspace=workspace,
+            out=[attn_output, softmax_lse],
+        )
+        attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        if (
+            not self.graph_mode
+            and forward_batch.num_token_non_padded_cpu != num_token_padding
+        ):
+            attn_output = torch.cat(
+                [
+                    attn_output,
+                    attn_output.new_zeros(
+                        num_token_padding - attn_output.shape[0], *attn_output.shape[1:]
+                    ),
+                ],
+                dim=0,
+            )
+        return attn_output
+
     def forward_decode_graph(
         self,
         q: torch.Tensor,
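
One detail of `forward_mtp` worth unpacking: in target-verify mode every request contributes exactly `speculative_num_draft_tokens` query tokens, so the cumulative per-request query boundaries that the `"TND"` layout expects form an arithmetic progression, which is what the `np.arange` call builds. A worked example with assumed sizes (3 requests, 4 draft tokens each):

```python
import numpy as np

speculative_num_draft_tokens = 4
num_query_tokens = 12  # assumed: 3 requests x 4 draft tokens each

# Cumulative end offset of each request's query segment: [4, 8, 12].
actual_seq_lengths = np.arange(
    speculative_num_draft_tokens,
    speculative_num_draft_tokens + num_query_tokens,
    speculative_num_draft_tokens,
)
print(actual_seq_lengths.tolist())  # [4, 8, 12]
```
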
@@ -690,3 +850,71 @@ class AscendAttnBackend(AttentionBackend):
             out=attn_output,
         )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
+
+
+class AscendAttnMultiStepDraftBackend:
+    """
+    Wrap multiple Ascend attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for _ in range(self.speculative_num_steps):
+            self.attn_backends.append(AscendAttnBackend(model_runner))
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs, max_num_tokens):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, call_fn)

sglang/srt/layers/attention/attention_registry.py
@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
         from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
             GDNAttnBackend,
             HybridLinearAttnBackend,
+            KimiLinearAttnBackend,
             Mamba2AttnBackend,
         )
         from sglang.srt.utils import is_blackwell, is_npu
@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
             linear_attn_backend = GDNAttnBackend(runner)
         elif runner.mamba2_config is not None:
             linear_attn_backend = Mamba2AttnBackend(runner)
+        elif runner.kimi_linear_config is not None:
+            linear_attn_backend = KimiLinearAttnBackend(runner)
         else:
             raise ValueError(
                 "Expected hybrid GDN or NemotronH models, but got unknown model."

sglang/srt/layers/attention/fla/chunk_delta_h.py
@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
 @triton.heuristics(
     {
         "USE_G": lambda args: args["g"] is not None,
+        "USE_GK": lambda args: args["gk"] is not None,
         "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
         "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
         "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     w,
     v_new,
     g,
+    gk,
     h,
     h0,
     ht,
@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     BT: tl.constexpr,
     BV: tl.constexpr,
     USE_G: tl.constexpr,
+    USE_GK: tl.constexpr,
     USE_INITIAL_STATE: tl.constexpr,
     STORE_FINAL_STATE: tl.constexpr,
     SAVE_NEW_VALUE: tl.constexpr,
@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     b_h4 = tl.zeros([64, BV], dtype=tl.float32)
 
     # calculate offset
-    h += (boh * H + i_h) * K * V
-    v += (bos * H + i_h) * V
-    k += (bos * Hg + i_h // (H // Hg)) * K
-    w += (bos * H + i_h) * K
+    h += ((boh * H + i_h) * K * V).to(tl.int64)
+    v += ((bos * H + i_h) * V).to(tl.int64)
+    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
+    w += ((bos * H + i_h) * K).to(tl.int64)
     if SAVE_NEW_VALUE:
-        v_new += (bos * H + i_h) * V
+        v_new += ((bos * H + i_h) * V).to(tl.int64)
     stride_v = H * V
     stride_h = H * K * V
     stride_k = Hg * K
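
The `.to(tl.int64)` casts above guard against 32-bit offset overflow: Triton evaluates these index expressions in int32 by default, and for long batches the flattened element offset can exceed 2**31 - 1 and wrap negative. Rough arithmetic with illustrative sizes:

```python
# Illustrative sizes only: a flattened offset like (bos * H + i_h) * V
# overflows int32 once it passes 2**31 - 1 = 2,147,483,647 elements.
bos, H, i_h, V = 140_000, 32, 0, 512  # ~140k tokens into the flattened batch

offset = (bos * H + i_h) * V  # 2,293,760,000
print(offset > 2**31 - 1)     # True: int32 pointer math would wrap negative
```
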
@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
             )
             tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
 
-        p_v = tl.make_block_ptr(
-            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
-        )
-        p_v_new = (
-            tl.make_block_ptr(
-                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
-            )
-            if SAVE_NEW_VALUE
-            else None
-        )
-        b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
         p_w = tl.make_block_ptr(
             w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
         )
         b_w = tl.load(p_w, boundary_check=(0, 1))
-        b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
+        b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
         if K > 64:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
+            b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
         if K > 128:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
+            b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
         if K > 192:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
-        b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
+            b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
+        p_v = tl.make_block_ptr(
+            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
+        )
+        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
 
         if SAVE_NEW_VALUE:
-            p_v_new = tl.make_block_ptr(
+            p_v = tl.make_block_ptr(
                 v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
             )
-            tl.store(
-                p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
-            )
+            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
 
+        last_idx = min((i_t + 1) * BT, T) - 1
         if USE_G:
-            last_idx = min((i_t + 1) * BT, T) - 1
             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
             p_g = tl.make_block_ptr(
                 g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
             )
             b_g = tl.load(p_g, boundary_check=(0,))
-            b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
+            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
             b_g_last = exp(b_g_last)
             b_h1 = b_h1 * b_g_last
             if K > 64:
@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
                 b_h3 = b_h3 * b_g_last
             if K > 192:
                 b_h4 = b_h4 * b_g_last
-        b_v_new = b_v_new.to(k.dtype.element_ty)
+
+        if USE_GK:
+            o_k1 = tl.arange(0, 64)
+            b_gk_last1 = tl.load(
+                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
+                mask=(o_k1 < K),
+                other=0.0,
+            )
+            b_h1 *= exp(b_gk_last1)[:, None]
+            if K > 64:
+                o_k2 = 64 + o_k1
+                b_gk_last2 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
+                    mask=(o_k2 < K),
+                    other=0.0,
+                )
+                b_h2 *= exp(b_gk_last2)[:, None]
+            if K > 128:
+                o_k3 = 128 + o_k1
+                b_gk_last3 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
+                    mask=(o_k3 < K),
+                    other=0.0,
+                )
+                b_h3 *= exp(b_gk_last3)[:, None]
+            if K > 192:
+                o_k4 = 192 + o_k1
+                b_gk_last4 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
+                    mask=(o_k4 < K),
+                    other=0.0,
+                )
+                b_h4 *= exp(b_gk_last4)[:, None]
+        b_v = b_v.to(k.dtype.element_ty)
+
         p_k = tl.make_block_ptr(
             k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
         )
         b_k = tl.load(p_k, boundary_check=(0, 1))
-        b_h1 += tl.dot(b_k, b_v_new)
+        b_h1 += tl.dot(b_k, b_v)
         if K > 64:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h2 += tl.dot(b_k, b_v_new)
+            b_h2 += tl.dot(b_k, b_v)
         if K > 128:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h3 += tl.dot(b_k, b_v_new)
+            b_h3 += tl.dot(b_k, b_v)
         if K > 192:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h4 += tl.dot(b_k, b_v_new)
+            b_h4 += tl.dot(b_k, b_v)
 
     # epilogue
     if STORE_FINAL_STATE:
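
The `USE_GK` branch generalizes the end-of-chunk decay from one scalar per head (`USE_G`) to one value per key channel: each 64-row block of the state is scaled row-wise by `exp(gk_last)`. A small PyTorch reference of the two decay modes, with shapes assumed for illustration:

```python
import torch

K, V = 128, 64
h = torch.randn(K, V)  # recurrent state [key_dim, value_dim] for one head

# USE_G: one scalar decay per head, applied uniformly.
g_last = torch.tensor(-0.1)
h_g = h * torch.exp(g_last)

# USE_GK: one decay per key channel, applied row-wise; this mirrors
# `b_h1 *= exp(b_gk_last1)[:, None]` in the kernel above.
gk_last = -torch.rand(K)
h_gk = h * torch.exp(gk_last)[:, None]
```
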
@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
     w: torch.Tensor,
     u: torch.Tensor,
     g: Optional[torch.Tensor] = None,
+    gk: Optional[torch.Tensor] = None,
     initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
         w=w,
         v_new=v_new,
         g=g,
+        gk=gk,
         h=h,
         h0=initial_state,
         ht=final_state,

sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
     USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
     IS_VARLEN: tl.constexpr,
+    IS_KDA: tl.constexpr,
 ):
     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
     i_n, i_hv = i_nh // HV, i_nh % HV
@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         p_beta = beta + (bos * HV + i_hv) * V + o_v
     else:
         p_beta = beta + bos * HV + i_hv
-    p_g = g + bos * HV + i_hv
+    if not IS_KDA:
+        p_g = g + bos * HV + i_hv
+    else:
+        p_gk = g + (bos * HV + i_hv) * K + o_k
+
     p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
 
     mask_k = o_k < K
@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
         b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
         b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
-        b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
             b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
             b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
-        b_h *= exp(b_g)
+        if not IS_KDA:
+            b_g = tl.load(p_g).to(tl.float32)
+            b_h *= exp(b_g)
+        else:
+            b_gk = tl.load(p_gk).to(tl.float32)
+            b_h *= exp(b_gk[:, None])
         # [BV]
         b_v -= tl.sum(b_h * b_k[:, None], 0)
         if IS_BETA_HEADWISE:
@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         p_k += H * K
         p_o += HV * V
         p_v += HV * V
-        p_g += HV
+        if not IS_KDA:
+            p_g += HV
+        else:
+            p_gk += HV * K
         p_beta += HV * (V if IS_BETA_HEADWISE else 1)
 
     if STORE_FINAL_STATE:
@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
         BV=BV,
         IS_BETA_HEADWISE=beta.ndim == v.ndim,
         USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        IS_KDA=False,
         num_warps=num_warps,
         num_stages=num_stages,
     )
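
Taken together, the `IS_KDA` flag switches the recurrent kernel's state decay from one scalar gate per head and step to a per-key-channel gate, as used by the new KDA (Kimi Delta Attention) path; note the existing launcher above still passes `IS_KDA=False`. A compact PyTorch reference of one recurrence step under both flags; shapes and the helper name are illustrative, not from the kernel:

```python
import torch

def gated_delta_step(h, q, k, v, beta, g, is_kda):
    """One gated-delta-rule step for a single head (reference sketch).

    h: [K, V] state; q, k: [K]; v: [V]; beta: scalar;
    g: scalar gate when is_kda=False, else a [K] per-channel gate.
    """
    # Decay the state: uniformly (scalar gate) or per key channel (KDA).
    h = h * (torch.exp(g)[:, None] if is_kda else torch.exp(g))
    v_new = beta * (v - (h * k[:, None]).sum(0))  # delta-rule correction
    h = h + k[:, None] * v_new[None, :]           # rank-1 state update
    o = (h * q[:, None]).sum(0)                   # output: q^T h
    return h, o

K, V = 8, 4
h = torch.zeros(K, V)
q, k, v = torch.randn(K), torch.randn(K), torch.randn(V)
h, o = gated_delta_step(h, q, k, v, torch.tensor(0.5), -torch.rand(K), is_kda=True)
```
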