sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
@@ -64,7 +64,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True

     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 0

     def forward_extend(
         self,
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         if not self.use_mla:
             if save_kv_cache:
@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):

             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +208,61 @@ class AscendAttnBackend(AttentionBackend):
                 )

             else:
-
-
-
-
-
-
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )

-
-
-
-
-
-
-
-
-
-
-
-
-
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
@@ -253,6 +288,136 @@ class AscendAttnBackend(AttentionBackend):

         return attn_output

+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
+            output = torch.empty(
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
+            )
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+
     def forward_decode(
         self,
         q: torch.Tensor,
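Illustrative note (not part of the diff): the MLA branch of the new forward_decode_graph above operates on a compressed latent KV cache, with the query split into a no-RoPE part in the kv_lora_rank space and a small RoPE part, while the paged cache stores the compressed KV plus the RoPE key. The following is only a shape sketch of that layout; the sizes and names are illustrative assumptions, not values taken from sglang.

import torch

# Assumed, DeepSeek-style MLA sizes for illustration only.
num_tokens, num_q_heads, page_size = 4, 16, 128
kv_lora_rank, qk_rope_head_dim = 512, 64

# Query split: the "nope" part lives in the compressed latent space,
# the rope part carries only the rotary dimensions.
q_nope = torch.randn(num_tokens, num_q_heads, 1, kv_lora_rank)
q_rope = torch.randn(num_tokens, num_q_heads, 1, qk_rope_head_dim)

# Paged cache: one shared latent vector and one rope key per cached token.
num_pages = 8
c_kv_cache = torch.randn(num_pages, 1, page_size, kv_lora_rank)
k_rope_cache = torch.randn(num_pages, 1, page_size, qk_rope_head_dim)

# The value path reuses the latent cache, so the decode output per token
# also lives in the kv_lora_rank space before the final projection.
out = torch.zeros(num_tokens, num_q_heads * kv_lora_rank)
print(q_nope.shape, q_rope.shape, c_kv_cache.shape, out.shape)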
@@ -260,106 +425,74 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool =
+        save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, forward_batch.out_cache_loc, k, v
                 )
             num_tokens = q.shape[0]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                num_key_value_heads=layer.tp_k_head_num,
-                input_layout="BSH",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-            )
-            )
-            attn_output = torch.empty(
-                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                dtype=q.dtype,
-                device=q.device,
-            )
-            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-            torch_npu.npu_fused_infer_attention_score.out(
-                query,
-                k_cache,
-                v_cache,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q.view(
+                        forward_batch.batch_size,
+                        -1,
+                        layer.tp_q_head_num,
+                        layer.qk_head_dim,
+                    ),
+                    k_cache.view(
+                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                    ),
+                    v_cache.view(
+                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                    ),
                     num_heads=layer.tp_q_head_num,
                     num_key_value_heads=layer.tp_k_head_num,
-                input_layout="
+                    input_layout="BSND",
+                    atten_mask=None,
+                    block_size=self.page_size,
+                    block_table=self.forward_metadata.block_tables,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                     scale=layer.scaling,
-                actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                workspace=workspace,
-                out=[attn_output, softmax_lse],
                 )
             else:
-
-
-
+                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                num_tokens = query.shape[0]
+                attn_output = torch.empty(
+                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
                 )
-            if self.use_fia:
-                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-                    q.view(
-                        forward_batch.batch_size,
-                        -1,
-                        layer.tp_q_head_num,
-                        layer.qk_head_dim,
-                    ),
-                    k_cache.view(
-                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
-                    ),
-                    v_cache.view(
-                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
-                    ),
-                    num_heads=layer.tp_q_head_num,
-                    num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSND",
-                    atten_mask=None,
-                    block_size=self.page_size,
-                    block_table=self.forward_metadata.block_tables,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
-                    scale=layer.scaling,
-                )
-            else:
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )

-
-
-
-
-
-
-
-
-
-
-
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    out=attn_output,
+                )
             return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if save_kv_cache:
@@ -370,9 +503,7 @@ class AscendAttnBackend(AttentionBackend):
             kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

-            if
-                layer.tp_q_head_num // layer.tp_k_head_num
-            ) >= 8:
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
                 """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
                 kv_c = kv_c.view(
                     -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
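Illustrative note (not part of the diff): the decode changes above add a graph-mode fast path ahead of the existing MLA / non-MLA branches, and split the non-MLA case into a fused infer-attention (FIA) kernel and a paged-attention fallback. The sketch below only mirrors that branch order with hypothetical names (DecodePaths, select_decode_path); it is not sglang code.

from dataclasses import dataclass
from typing import Callable

# Hypothetical stand-ins for the real kernel paths; names are illustrative only.
@dataclass
class DecodePaths:
    graph_decode: Callable[[], str]
    fused_infer_attention: Callable[[], str]  # FIA kernel path
    paged_attention: Callable[[], str]
    mla_decode: Callable[[], str]

def select_decode_path(
    paths: DecodePaths, *, graph_mode: bool, use_mla: bool, use_fia: bool
) -> str:
    """Mirror the branch order the diff introduces in forward_decode."""
    if graph_mode:                      # captured/replayed graphs take the dedicated path
        return paths.graph_decode()
    if not use_mla:
        if use_fia:                     # fused infer-attention kernel when available
            return paths.fused_infer_attention()
        return paths.paged_attention()  # otherwise the classic paged-attention kernel
    return paths.mla_decode()           # MLA models use the latent-cache path

paths = DecodePaths(
    graph_decode=lambda: "graph",
    fused_infer_attention=lambda: "fia",
    paged_attention=lambda: "paged",
    mla_decode=lambda: "mla",
)
# A non-MLA model without FIA, outside graph mode, falls back to paged attention.
assert select_decode_path(paths, graph_mode=False, use_mla=False, use_fia=False) == "paged"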
sglang/srt/layers/attention/hybrid_attn_backend.py
CHANGED
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""

     def __init__(
-        self,
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )

     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
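Illustrative note (not part of the diff): every hook in the hybrid backend above now routes through the same decode-or-idle check and delegates to either the prefill or the decode backend. The snippet below is a minimal sketch of that delegation pattern with hypothetical names (HybridRouter, StubBackend); the real prefill and decode backends are full attention implementations.

class StubBackend:
    """Toy stand-in that only records which batches it was asked to prepare."""
    def __init__(self, name: str):
        self.name = name
        self.calls = []

    def init_forward_metadata(self, batch) -> None:
        self.calls.append(batch)

class HybridRouter:
    """Sketch of the routing the diff extends: check the forward mode once,
    then forward the call unchanged to the matching backend."""
    def __init__(self, prefill: StubBackend, decode: StubBackend):
        self.prefill = prefill
        self.decode = decode

    def init_forward_metadata(self, batch) -> None:
        target = self.decode if batch["is_decode_or_idle"] else self.prefill
        target.init_forward_metadata(batch)

router = HybridRouter(StubBackend("prefill"), StubBackend("decode"))
router.init_forward_metadata({"is_decode_or_idle": True})
router.init_forward_metadata({"is_decode_or_idle": False})
assert len(router.decode.calls) == 1 and len(router.prefill.calls) == 1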
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:

     workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None


 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )

         # Custom fast-path for decode/idle.
-
-
+        # Capture with full width so future longer sequences are safe during replay
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]

         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -217,13 +219,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             None,
             block_kv_indices,
             self.req_to_token.stride(0),
-
+            max_blocks_per_seq,
             NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )

+        # Record the true maximum sequence length for this capture batch so that
+        # the kernel launch path (which requires an int not a tensor) can reuse
+        # it safely during both capture and replay.
+        max_seq_len_val = int(seq_lens.max().item())
+
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
+            self.decode_cuda_graph_workspace,
+            block_kv_indices,
+            max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

+        # Update stored max_seq_len so subsequent kernel calls use the correct value
+        # Prefer CPU tensor to avoid GPU synchronization when available.
+        if seq_lens_cpu is not None:
+            metadata.max_seq_len = int(seq_lens_cpu.max().item())
+        else:
+            metadata.max_seq_len = int(seq_lens.max().item())
+
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             forward_batch.seq_lens.device,
         )

+        max_seq_len_val = int(max_seq)
         self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices
+            self.workspace_buffer, block_kv_indices, max_seq_len_val
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata

@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=
+            max_seq_len=metadata.max_seq_len,
             bmm1_scale=bmm1_scale,
         )

-        #
-
-        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-
+        # Reshape output directly without slicing
+        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output


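Illustrative note (not part of the diff): the TRT-LLM MLA changes above cache the maximum sequence length as a plain Python int on the decode metadata, preferring the CPU copy of seq_lens so the kernel launch does not have to synchronize with the GPU. The sketch below restates that idea with hypothetical names (DecodeMetadataSketch, update_max_seq_len); it is not the backend's actual API.

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class DecodeMetadataSketch:
    # Mirrors the shape of the metadata change: the kernel wants a Python int,
    # not a tensor, so the max sequence length rides alongside the indices.
    block_kv_indices: torch.Tensor
    max_seq_len: Optional[int] = None

def update_max_seq_len(
    meta: DecodeMetadataSketch,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor] = None,
) -> None:
    """Prefer the CPU copy when it exists: .max().item() on a GPU tensor
    forces a device synchronization, which is what the diff avoids."""
    source = seq_lens_cpu if seq_lens_cpu is not None else seq_lens
    meta.max_seq_len = int(source.max().item())

meta = DecodeMetadataSketch(block_kv_indices=torch.zeros(2, 4, dtype=torch.int32))
update_max_seq_len(meta, seq_lens=torch.tensor([17, 33]), seq_lens_cpu=torch.tensor([17, 33]))
assert meta.max_seq_len == 33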
sglang/srt/layers/layernorm.py
CHANGED
@@ -53,7 +53,7 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

-if
+if _is_npu:
     import torch_npu


@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x

-
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

-    def
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)

+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"

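Illustrative note (not part of the diff): both RMSNorm paths added above compute the same Gemma-style normalization, normalizing in float32 and scaling by (1 + weight) before casting back, as the in-code comments about the (x * w).to(float16) ordering point out. The snippet below is a plain-PyTorch reference of that math under a hypothetical name (gemma_rms_norm_reference); it is not sglang's implementation.

import torch

def gemma_rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """Normalize in float32, scale by (1 + weight), then cast back to the input dtype."""
    orig_dtype = x.dtype
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    out = x32 * torch.rsqrt(variance + eps)
    out = out * (1.0 + weight.float())
    return out.to(orig_dtype)

x = torch.randn(2, 8).half()
w = torch.zeros(8)  # weight starts at zero, so (1 + w) is initially the identity scale
print(gemma_rms_norm_reference(x, w).dtype)  # torch.float16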