sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

```diff
@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128
 
+global_zero_init_workspace_buffer = None
+
 
 @dataclass
 class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-
-
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
```
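The hunk above replaces a per-instance workspace allocation (the three removed lines, whose content is not captured in this view) with a single zero-initialized buffer shared by every TRTLLMMLABackend instance in the process. A minimal sketch of that lazy-initialization pattern, using hypothetical names (`_shared_workspace`, `get_shared_workspace`) rather than the module's actual ones:

```python
import torch

# Hypothetical module-level cache playing the role of
# global_zero_init_workspace_buffer in the hunk above.
_shared_workspace = None


def get_shared_workspace(size_bytes: int, device: str = "cuda") -> torch.Tensor:
    """Return a zero-initialized uint8 workspace shared across backend instances."""
    global _shared_workspace
    if _shared_workspace is None:
        # Allocated once; later callers reuse the same buffer instead of
        # paying for another size_bytes allocation per attention backend.
        _shared_workspace = torch.zeros(size_bytes, dtype=torch.uint8, device=device)
    return _shared_workspace
```

With the default DEFAULT_WORKSPACE_SIZE_MB of 128, each additional backend instance that would otherwise allocate its own buffer now reuses the shared one instead.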
```diff
@@ -287,38 +294,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        # 2. Quantization of all components to FP8 format
+        # 3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if
-
-
-
-
-
-
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
 
         # Prepare query tensor inline
-        if
-            #
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            #
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
```
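Both decode paths end up with the same query layout: the FP16 branch concatenates `q_nope` and `q_rope` along the last dimension, while the FP8 branch pre-allocates `q_out` and has the fused kernel write the nope part into `[..., :kv_lora_rank]` and the rope part into `[..., kv_lora_rank:]`. A small sketch of that equivalence, with RoPE and FP8 quantization omitted and illustrative dimensions assumed (512/64 are DeepSeek-style MLA values, not read from the diff):

```python
import torch

seq_len, num_heads, kv_lora_rank, rope_dim = 4, 16, 512, 64
q_nope = torch.randn(seq_len, num_heads, kv_lora_rank, dtype=torch.bfloat16)
q_rope = torch.randn(seq_len, num_heads, rope_dim, dtype=torch.bfloat16)

# FP16 path: concatenate nope and rope parts along the last dimension.
merged_cat = torch.cat([q_nope, q_rope], dim=-1)

# FP8 path (layout only): write each part into a pre-allocated output,
# mirroring the slicing that quantize_and_rope_for_fp8 passes to the kernel.
merged_sliced = torch.empty(
    seq_len, num_heads, kv_lora_rank + rope_dim, dtype=torch.bfloat16
)
merged_sliced[..., :kv_lora_rank] = q_nope
merged_sliced[..., kv_lora_rank:] = q_rope

assert torch.equal(merged_cat, merged_sliced)  # identical [nope | rope] layout
```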
```diff
@@ -327,9 +431,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
 
         # Get metadata
         metadata = (
@@ -337,11 +439,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )
 
-        # Scale computation for TRTLLM MLA kernel:
-        #
-        #
-        # -
-        #
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
```
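The rewritten comment spells out how the kernel's BMM1 scale is assembled: `bmm1_scale = q_scale * k_scale * softmax_scale`, with `softmax_scale` supplied by `layer.scaling`. A tiny numeric sketch of that formula (the head_dim value here is a placeholder, not taken from the diff):

```python
import math

head_dim = 128                 # placeholder; layer.scaling carries the model's real value
q_scale = 1.0                  # query scale, 1.0 on both FP16 and FP8 paths
k_scale = 1.0                  # checkpoint k_scale, defaulting to 1.0 when absent
softmax_scale = 1.0 / math.sqrt(head_dim)

bmm1_scale = q_scale * k_scale * softmax_scale
print(f"{bmm1_scale:.6f}")     # 0.088388 for head_dim=128
```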
sglang/srt/layers/attention/vision.py

```diff
@@ -245,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -253,6 +255,8 @@ class VisionTritonAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
 
         # [b * s, head, head_size]
         output = torch.empty_like(q)
```
```diff
@@ -401,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
```