sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128

+global_zero_init_workspace_buffer = None
+

 @dataclass
 class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-        self.workspace_buffer = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
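The hunk above replaces a per-instance, uninitialized int8 workspace (torch.empty) with one zero-initialized uint8 buffer that is allocated lazily at module level and shared by every backend instance. A minimal standalone sketch of that lazy-singleton pattern, with all names invented for illustration:

import torch

WORKSPACE_SIZE_MB = 128
_shared_workspace = None  # module-level singleton, allocated on first use


def get_shared_workspace(device="cpu") -> torch.Tensor:
    # Return one zero-initialized uint8 workspace shared by every caller.
    global _shared_workspace
    if _shared_workspace is None:
        _shared_workspace = torch.zeros(
            WORKSPACE_SIZE_MB * 1024 * 1024, dtype=torch.uint8, device=device
        )
    return _shared_workspace


class DummyBackend:
    def __init__(self, device="cpu"):
        # Every instance aliases the same buffer instead of allocating its own.
        self.workspace_buffer = get_shared_workspace(device)

Sharing one buffer avoids allocating a fresh 128 MB workspace per backend instance, and zeroing it (rather than torch.empty) gives the kernel a deterministic initial workspace state.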
@@ -287,38 +294,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata

+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        #   1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        #   2. Quantization of all components to FP8 format
+        #   3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
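For orientation, here is a rough unfused equivalent of what the fused flashinfer.rope.mla_rope_quantize_fp8 call above computes, assuming NeoX-style rotation and cos/sin supplied as separate [seq_len, rope_dim // 2] tensors. This is an independent sketch with invented names, not the kernel's implementation:

import torch


def rope_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, num_heads, rope_dim]; cos/sin: [seq_len, rope_dim // 2]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos = cos[:, None, :]  # broadcast over the head dimension
    sin = sin[:, None, :]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


def quantize_and_rope_reference(q_nope, q_rope, k_nope, k_rope, cos, sin):
    fp8 = torch.float8_e4m3fn
    # Rotate only the rope parts, then quantize everything to FP8.
    q_rope = rope_neox(q_rope, cos, sin)
    k_rope = rope_neox(k_rope, cos, sin)
    # Merge the query as [nope | rope] along the last dim, matching q_out above.
    q_out = torch.cat([q_nope, q_rope], dim=-1).to(fp8)
    return q_out, k_nope.to(fp8), k_rope.to(fp8)

The fused kernel performs the same rotation, quantization, and output placement in one pass, writing directly into the pre-allocated slices of q_out.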
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if k is not None and save_kv_cache:
-            cache_loc = forward_batch.out_cache_loc
-            if k_rope is not None:
-                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                    layer, cache_loc, k, k_rope
-                )
-            elif v is not None:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )

         # Prepare query tensor inline
-        if q_rope is not None:
-            # q contains NOPE part (v_head_dim)
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            # q already has both parts
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
@@ -327,9 +431,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)

         # Get metadata
         metadata = (
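The one-liner above keeps the layout the TRT-LLM kernel expects: the flat token cache reshaped into pages plus a singleton KV dimension. A quick shape check with assumed sizes (576 = 512 kv_lora_rank + 64 qk_rope_head_dim for a DeepSeek-style MLA cache; all values here are illustrative):

import torch

num_pages, page_size, kv_cache_dim = 4, 32, 576  # assumed values for illustration
k_cache = torch.zeros(num_pages * page_size, kv_cache_dim)

kv_cache = k_cache.view(-1, page_size, kv_cache_dim).unsqueeze(1)
assert kv_cache.shape == (num_pages, 1, page_size, kv_cache_dim)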
@@ -337,11 +439,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )

-        # Scale computation for TRTLLM MLA kernel:
-        # - BMM1 scale = q_scale * k_scale * softmax_scale
-        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
-        # - k_scale is read from model checkpoint if available
-        # TODO: Change once fp8 path is supported
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
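A worked example of the BMM1 scale described in the rewritten comment. The numbers are placeholders; in the backend, 1/sqrt(head_dim) arrives pre-computed as layer.scaling and k_scale comes from layer.k_scale_float when the checkpoint provides one:

import math

head_dim = 128                              # placeholder head dimension
softmax_scale = 1.0 / math.sqrt(head_dim)   # what layer.scaling pre-computes
q_scale = 1.0                               # unchanged for both FP16 and FP8 paths
k_scale = 1.0                               # replaced by layer.k_scale_float if available

bmm1_scale = q_scale * k_scale * softmax_scale
print(bmm1_scale)  # ~0.0884 for head_dim = 128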
sglang/srt/layers/attention/vision.py

@@ -245,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -253,6 +255,8 @@
         Returns:
            [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)

         # [b * s, head, head_size]
         output = torch.empty_like(q)
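When cu_seqlens is not supplied, the new fallback derives it from the uniform (bsz, seq_len) shape. A sketch of what such a helper plausibly computes (the real implementation is _get_cu_seqlens_for_shape in vision.py; this is an independent illustration, not its source):

import torch


def cu_seqlens_for_uniform_batch(bsz: int, seq_len: int, device="cpu") -> torch.Tensor:
    # Cumulative sequence lengths: [0, seq_len, 2*seq_len, ..., bsz*seq_len]
    return torch.arange(
        0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=device
    )


print(cu_seqlens_for_uniform_batch(3, 4))  # tensor([ 0,  4,  8, 12], dtype=torch.int32)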
@@ -401,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-                qkv_backend = "sdpa"
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
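With this change, CUDA deployments default the multimodal attention backend to triton_attn when nothing is configured. A hedged usage sketch for pinning the previous default instead, assuming the mm_attention_backend server argument (read via global_server_args_dict above) is accepted as an Engine keyword; the model path is an arbitrary example:

import sglang as sgl

llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-VL-7B-Instruct",  # example model, not from this diff
    mm_attention_backend="sdpa",               # force SDPA rather than the new CUDA default
)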