sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        # 2. Quantization of all components to FP8 format
+        # 3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if k is not None and save_kv_cache:
-            cache_loc = forward_batch.out_cache_loc
-            if k_rope is not None:
-                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                    layer, cache_loc, k, k_rope
-                )
-            elif v is not None:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
 
         # Prepare query tensor inline
-        if q_rope is not None:
-            # q contains NOPE part (v_head_dim)
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            # q already has both parts
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
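Note: the docstring and output-slicing comments in the hunk above describe how the merged FP8 query is laid out: the quantized nope part fills the first kv_lora_rank channels and the RoPE-rotated part the trailing qk_rope_head_dim channels. A minimal layout sketch, with assumed DeepSeek-style MLA sizes (kv_lora_rank=512, qk_rope_head_dim=64) that are illustrative rather than taken from this diff:

import torch

# Merged query layout: [nope | rope], matching q_nope_out=q_out[..., :kv_lora_rank]
# and q_rope_out=q_out[..., kv_lora_rank:] in the fused kernel call above.
kv_lora_rank, qk_rope_head_dim = 512, 64      # assumed MLA dimensions
seq_len, num_heads = 4, 16                    # illustrative decode batch
q_out = torch.empty(
    seq_len, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=torch.float8_e4m3fn
)
nope_view = q_out[..., :kv_lora_rank]         # written via q_nope_out
rope_view = q_out[..., kv_lora_rank:]         # written via q_rope_out
assert nope_view.shape[-1] + rope_view.shape[-1] == q_out.shape[-1]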
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
 
         # Get metadata
         metadata = (
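Note: the removed comment explained that TRT-LLM expects the paged KV data with an extra leading dimension; the new one-liner folds the page view and the unsqueeze together. A standalone sketch of that reshape, using illustrative sizes that are assumptions rather than values from this diff:

import torch

num_pages, page_size, kv_cache_dim = 8, 32, 576   # illustrative values only
k_cache = torch.zeros(num_pages * page_size, kv_cache_dim)
# View the flat pool as pages, then add the extra dimension expected by the kernel.
kv_cache = k_cache.view(-1, page_size, kv_cache_dim).unsqueeze(1)
print(kv_cache.shape)  # torch.Size([8, 1, 32, 576])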
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )
 
-        # Scale computation for TRTLLM MLA kernel:
-        # - BMM1 scale = q_scale * k_scale * softmax_scale
-        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
-        # - k_scale is read from model checkpoint if available
-        # TODO: Change once fp8 path is supported
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
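Note: a worked example of the BMM1 scale combination described in the rewritten comment above. The head_dim value below is illustrative; the real softmax scale is pre-computed as layer.scaling and may include model-specific adjustments:

import math

head_dim = 192                                # illustrative only
softmax_scale = 1.0 / math.sqrt(head_dim)     # pre-computed as layer.scaling
q_scale = 1.0                                 # 1.0 for both FP16 and FP8 paths
k_scale = 1.0                                 # checkpoint k_scale if present, else 1.0
bmm1_scale = q_scale * k_scale * softmax_scale
print(round(bmm1_scale, 6))                   # 0.072169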
sglang/srt/layers/attention/vision.py

@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
@@ -244,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -252,6 +255,8 @@ class VisionTritonAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
 
         # [b * s, head, head_size]
         output = torch.empty_like(q)
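Note: the new fallback builds cumulative sequence lengths from the batch shape when cu_seqlens is not passed. The _get_cu_seqlens_for_shape helper itself is not shown in this diff; the snippet below is only a sketch of the standard cu_seqlens layout for a fixed-shape batch:

import torch

# For bsz sequences of seq_len tokens each, varlen attention kernels expect the
# cumulative offsets [0, seq_len, 2*seq_len, ..., bsz*seq_len].
bsz, seq_len = 3, 5                           # illustrative values
cu_seqlens = torch.arange(0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)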
@@ -365,19 +370,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
        )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
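Note: the constructor now partitions heads over the attention tensor-parallel group rather than the global TP group. A small sketch of the resulting per-rank head count, with made-up values; dist_utils.divide requires even divisibility, which the assert below mirrors:

# Illustrative values; the real ones come from the model config and server args.
num_heads, num_dummy_heads, attn_tp_size = 16, 0, 4

total_heads = num_dummy_heads + num_heads
assert total_heads % attn_tp_size == 0, "heads must split evenly across attention-TP ranks"
num_attention_heads_per_partition = total_heads // attn_tp_size
print(num_attention_heads_per_partition)  # 4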
@@ -399,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-                qkv_backend = "sdpa"
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
@@ -427,6 +437,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +447,8 @@ class VisionAttention(nn.Module):
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +456,8 @@ class VisionAttention(nn.Module):
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )