sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,12 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
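The import-time gate above indexes os.environ directly, so it assumes SGLANG_ENABLE_TORCH_COMPILE is already exported when the module is imported (an unset variable raises KeyError). A minimal sketch of the same intent written with a default, shown only as an illustration rather than code from the package:

import logging
import os

import torch

# Same effect as the gate in the hunk above, but os.environ.get() tolerates an
# unset SGLANG_ENABLE_TORCH_COMPILE instead of raising KeyError at import time.
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    # Quiet torch.compile / dynamo logging and fall back to eager on errors.
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True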
@@ -332,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
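For orientation on the new q_rope/k_rope arguments: as the q_rope.view(..., layer.head_dim - layer.v_head_dim) reshape above suggests, layer.head_dim covers the compressed ("nope") part plus the rotary part of each MLA head, while layer.v_head_dim is the compressed part alone. A shape-only sketch under that assumption; the 512/64 split and tensor names below are illustrative, not taken from the diff:

import torch

# Illustrative MLA head layout (assumption: head_dim = nope width + rope width).
num_tokens, num_heads = 4, 16
v_head_dim = 512                 # compressed "nope" width (illustrative)
rope_dim = 64                    # rotary width (illustrative)
head_dim = v_head_dim + rope_dim

# A caller that keeps the rotary part separate passes q and q_rope individually.
q = torch.randn(num_tokens, num_heads * v_head_dim)
q_rope = torch.randn(num_tokens, num_heads * rope_dim)

# The backend reshapes them the same way forward_extend does above ...
q = q.view(-1, num_heads, v_head_dim)
q_rope = q_rope.view(-1, num_heads, head_dim - v_head_dim)

# ... and the ragged-prefill path concatenates them back into full-width heads.
qall = torch.cat([q, q_rope], dim=-1).view(-1, num_heads, head_dim)
assert qall.shape == (num_tokens, num_heads, head_dim)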
@@ -358,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-
-
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -375,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -382,20 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-
-
-
-
-
-
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
+
+        o = q_nope.new_empty(q_nope.shape)
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
-
-
-
-
+            q_nope,
+            q_rope,
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -825,16 +881,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
-
-
-        self.
-
-
-
-
-
-
-
-
-
-
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")
sglang/srt/layers/attention/merge_state.py (new file)

@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fallback to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fallback to Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
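The new merge_state helper combines two partial attention results (for example, a cached-prefix segment and a freshly computed suffix segment) using their per-head log-sum-exp values, dispatching to the sgl_kernel CUDA op when the dtype and head size allow it and to the Triton kernel otherwise. A hypothetical call sketch, assuming the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] and [NUM_TOKENS, NUM_HEADS] layouts documented in the comments; the shapes, dtypes, and CUDA device below are illustrative:

import torch

from sglang.srt.layers.attention.merge_state import merge_state

num_tokens, num_heads, head_size = 8, 16, 128

prefix_out = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
suffix_out = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
prefix_lse = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.float32)
suffix_lse = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.float32)

# Returns the merged attention output and (optionally) the merged log-sum-exp.
merged, merged_lse = merge_state(prefix_out, prefix_lse, suffix_out, suffix_lse)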
sglang/srt/layers/attention/triton_ops/merge_state.py (new file)

@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
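The kernel above performs the standard numerically stable log-sum-exp merge of two attention partials. A plain PyTorch sketch of the same arithmetic, offered only as an illustrative reference for checking the Triton path on small inputs (the function name is hypothetical):

import torch


def merge_state_reference(v_a, s_a, v_b, s_b):
    # v_a, v_b: [num_tokens, num_heads, head_size] partial attention outputs
    # s_a, s_b: [num_tokens, num_heads] log-sum-exp of their attention weights
    # Mirror the kernel: treat +inf LSEs (empty segments) as -inf so exp() gives 0.
    s_a = torch.where(s_a == float("inf"), torch.full_like(s_a, float("-inf")), s_a)
    s_b = torch.where(s_b == float("inf"), torch.full_like(s_b, float("-inf")), s_b)

    max_lse = torch.maximum(s_a, s_b)
    p = torch.exp(s_a - max_lse)
    q = torch.exp(s_b - max_lse)
    denom = p + q

    out = v_a * (p / denom).unsqueeze(-1) + v_b * (q / denom).unsqueeze(-1)
    out_lse = torch.log(denom) + max_lse
    return out, out_lse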