sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        # 2. Quantization of all components to FP8 format
+        # 3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if
-
-
-
-
-
-
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
 
         # Prepare query tensor inline
-        if
-            #
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            #
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
 
         # Get metadata
         metadata = (
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )
 
-        # Scale computation for TRTLLM MLA kernel:
-        #
-        #
-        # -
-        #
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
sglang/srt/layers/attention/vision.py

@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
@@ -244,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -252,6 +255,8 @@ class VisionTritonAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
 
         # [b * s, head, head_size]
         output = torch.empty_like(q)
@@ -365,19 +370,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-
-
-        self.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -399,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
@@ -427,6 +437,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +447,8 @@ class VisionAttention(nn.Module):
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +456,8 @@ class VisionAttention(nn.Module):
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
            prefix=add_prefix("proj", prefix),
        )
 
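On the VisionTritonAttention change above: the new bsz and seq_len arguments let the backend fall back to a computed cu_seqlens when the caller passes None. _get_cu_seqlens_for_shape is the helper the diff calls; for a dense batch in which every sequence has the same length, an equivalent computation (a hypothetical stand-in, not the actual helper) is:

import torch

def cu_seqlens_for_dense_batch(bsz: int, seq_len: int, device=None) -> torch.Tensor:
    # Hypothetical stand-in for _get_cu_seqlens_for_shape: cumulative sequence
    # lengths for bsz equal-length sequences are [0, seq_len, 2*seq_len, ..., bsz*seq_len].
    return torch.arange(
        0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32, device=device
    )

print(cu_seqlens_for_dense_batch(3, 4))  # tensor([ 0,  4,  8, 12], dtype=torch.int32)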