sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available

 if is_flashinfer_available():
     import flashinfer
@@ -30,7 +30,12 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q

 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@@ -122,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]

+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -207,12 +214,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""

         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -223,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -260,12 +270,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -277,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
         metadata = self.decode_cuda_graph_metadata[bs]

         # Update block indices for new sequences.
@@ -327,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             bs = forward_batch.batch_size

             # Get maximum sequence length.
@@ -336,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()

+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-
-
+                seq_lens,
+                seq_lens.device,
             )

             max_seq_len_val = int(max_seq)
@@ -482,7 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query =
+            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -545,49 +568,163 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
-        if (
-            forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
-        ):
+        if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
-
-
-
-
+
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
             )

-
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-
-
-
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
                 value=v,
                 workspace_buffer=self.workspace_buffer,
-                seq_lens=
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
                 max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                 bmm1_scale=layer.scaling,
                 bmm2_scale=1.0,
-                o_sf_scale
+                o_sf_scale=-1.0,
                 batch_size=forward_batch.batch_size,
                 window_left=-1,
                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                 enable_pdl=False,
-                is_causal=
-                return_lse=
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
             )
-
-
-
-
-
-
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=forward_batch.mha_return_lse,
+        )


 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
@@ -605,3 +742,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
             kv_indptr_buf=self.kv_indptr[i],
             q_indptr_decode_buf=self.q_indptr_decode,
         )
+
+
+def _concat_mla_absorb_q_general(q_nope, q_rope):
+    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+        return concat_mla_absorb_q(q_nope, q_rope)
+    else:
+        return torch.cat([q_nope, q_rope], dim=-1)
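Note on the new `_concat_mla_absorb_q_general` helper above: when the fused `sgl_kernel.concat_mla_absorb_q` op is unavailable (non-CUDA builds) or the head dims are not the (512, 64) MLA absorb case, it degrades to a plain concatenation. Below is a minimal sketch of that fallback path only, using plain PyTorch and illustrative tensor shapes; it is not the fused kernel path.

```python
# Sketch of the fallback branch only (the fused sgl_kernel path is not shown here).
import torch

def concat_absorb_q_fallback(q_nope: torch.Tensor, q_rope: torch.Tensor) -> torch.Tensor:
    # q_nope: [num_tokens, num_heads, 512], q_rope: [num_tokens, num_heads, 64]
    return torch.cat([q_nope, q_rope], dim=-1)  # -> [num_tokens, num_heads, 576]

q_nope, q_rope = torch.randn(4, 16, 512), torch.randn(4, 16, 64)
assert concat_absorb_q_fallback(q_nope, q_rope).shape == (4, 16, 576)
```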
sglang/srt/layers/attention/vision.py
CHANGED
@@ -16,14 +16,19 @@ from sglang.srt.utils import (
     get_device_capability,
     is_blackwell,
     is_cuda,
+    is_npu,
     print_info_once,
 )

 _is_cuda = is_cuda()
+_is_npu = is_npu()

 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func

+if _is_npu:
+    import torch_npu
+
 from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
         return output


+class VisionAscendAttention(nn.Module):
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_npu:
+            raise Exception("VisionAscendAttention is only available for ascend npu")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        if seq_lens.is_npu:
+            # cu_seqlens must be on cpu because of operator restriction
+            seq_lens = seq_lens.to("cpu")
+        _, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[1]
+        output = torch.empty_like(q)
+
+        # operator requires pta version >= 2.5.1
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=seq_lens.to(torch.int32),
+            scale_value=head_size**-0.5,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            out=output,
+        )
+
+        return output
+
+
 QKV_BACKEND_IMPL = {
     "triton_attn": VisionTritonAttention,
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
+    "ascend_attn": VisionAscendAttention,
 }

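The `QKV_BACKEND_IMPL` change above registers the Ascend implementation under the `"ascend_attn"` key. As a minimal sketch, this is how such a string-keyed registry is typically consumed; the names below are hypothetical stubs, not sglang's actual selection code.

```python
# Hypothetical registry lookup; only the dict-dispatch pattern is illustrated.
class VisionSdpaAttentionStub: ...
class VisionAscendAttentionStub: ...

QKV_BACKEND_IMPL = {
    "sdpa": VisionSdpaAttentionStub,
    "ascend_attn": VisionAscendAttentionStub,
}

def build_vision_attention(backend_name: str):
    try:
        return QKV_BACKEND_IMPL[backend_name]()
    except KeyError:
        raise ValueError(f"Unknown vision attention backend: {backend_name!r}")

attn = build_vision_attention("ascend_attn")
```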
sglang/srt/layers/attention/wave_backend.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput

 logger = logging.getLogger(__name__)

@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"

@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
sglang/srt/layers/communicator.py
CHANGED
@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual

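The communicator change above threads an optional `cache` argument through the prepare step and stashes it on the shared context so `prepare_weight_cache` can run right after the tensor-parallel all-reduce. A minimal sketch of that stash-then-consume pattern, with hypothetical names rather than the actual sglang classes:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class _Context:
    cache: Optional[Any] = None

class _Communicator:
    def __init__(self) -> None:
        self._context = _Context()

    def prepare_attn(self, hidden_states, cache=None):
        if cache is not None:
            self._context.cache = cache  # stash now, consume after the all-reduce
        return hidden_states

    def after_all_reduce(self, hidden_states):
        if self._context.cache is not None:
            pass  # e.g. prepare_weight_cache(hidden_states, self._context.cache)
        return hidden_states
```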
sglang/srt/layers/dp_attention.py
CHANGED
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.utils import get_bool_env_var, is_hip

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
 _LOCAL_ATTN_DP_RANK: Optional[int] = None
 _ENABLE_DP_ATTENTION_FLAG: bool = False

+_is_hip = is_hip()
+_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
+

 class DpPaddingMode(IntEnum):

@@ -67,7 +71,12 @@ class DpPaddingMode(IntEnum):

     @classmethod
     def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
-
+        # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
+        # it can be safely removed later, once RCCL fixed
+        if _USE_ROCM700A_WA:
+            return cls.SUM_LEN
+        else:
+            return cls.MAX_LEN


 class _DpGatheredBufferWrapper:
@@ -254,6 +263,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
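The ROCm 7.0.0 alpha workaround above is gated by the `SGLANG_USE_ROCM700A` environment variable on HIP builds. The sketch below shows the gating logic under the assumption that `get_bool_env_var` treats values like "1" or "true" as truthy; the real helper is not part of this diff and may differ in detail.

```python
import os

def _get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics; sglang's real helper may differ.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

use_rocm700a_wa = _get_bool_env_var("SGLANG_USE_ROCM700A")
default_padding_mode = "SUM_LEN" if use_rocm700a_wa else "MAX_LEN"
print(default_padding_mode)
```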
sglang/srt/layers/elementwise.py
CHANGED
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(

 def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert len(x.shape) == 2
-    assert
+    assert (
+        x.shape == residual.shape and x.dtype == residual.dtype
+    ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
     if autotune:
sglang/srt/layers/layernorm.py
CHANGED
sglang/srt/layers/linear.py
CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs

 if TYPE_CHECKING:
@@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
         if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-
-
-
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[output_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, output_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    output_dim, start_idx, shard_size
+                )

         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
                 shard_size,
             )
         else:
-
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    input_dim, start_idx, shard_size
+                )

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
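Both weight-loader paths above now call `pad_or_narrow_weight` when `start_idx + shard_size` would run past the checkpoint tensor (e.g. a Qwen2.5-VL MLP dimension that is not 8-aligned). The helper lives in `sglang/srt/layers/utils.py` and its body is not part of this diff; the sketch below shows one plausible zero-padding behavior and is an assumption, not the actual implementation.

```python
import torch

def pad_or_narrow_weight_sketch(w: torch.Tensor, dim: int, start: int, size: int) -> torch.Tensor:
    # Assumed behavior: take what exists past `start`, then zero-pad to the shard size.
    available = max(w.shape[dim] - start, 0)
    piece = w.narrow(dim, start, min(size, available))
    if piece.shape[dim] == size:
        return piece
    pad_shape = list(piece.shape)
    pad_shape[dim] = size - piece.shape[dim]
    return torch.cat([piece, piece.new_zeros(pad_shape)], dim=dim)

w = torch.randn(10, 7)
assert pad_or_narrow_weight_sketch(w, 1, 4, 4).shape == (10, 4)  # only 3 columns exist, 1 zero-padded
```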
sglang/srt/layers/logits_processor.py
CHANGED
@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.logit_scale = logit_scale
         self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            if
+            if self.use_fp32_lm_head:
+                logits = torch.matmul(
+                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
+                )
+            elif use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-
+            if self.use_fp32_lm_head:
+                with torch.cuda.amp.autocast(enabled=False):
+                    logits = lm_head.quant_method.apply(
+                        lm_head, hidden_states.to(torch.float32), embedding_bias
+                    )
+            else:
+                logits = lm_head.quant_method.apply(
+                    lm_head, hidden_states, embedding_bias
+                )

         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -1104,10 +1104,10 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
+    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # block size of quantization
     assert hidden_size % BLOCK_D == 0
     grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
     _fwd_kernel_ep_gather[grid](
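The `ep_gather` change above derives the quantization block size from `hidden_size` divisibility instead of a CI flag. A small worked example of the resulting block size and launch grid, using a plain ceil-div in place of `triton.cdiv`:

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

for hidden_size, num_tokens in [(4096, 2000), (2560, 512)]:
    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # same formula as the diff
    assert hidden_size % BLOCK_D == 0
    grid = (ceil_div(hidden_size, BLOCK_D), min(num_tokens, 1024))
    print(hidden_size, BLOCK_D, grid)  # -> 4096: BLOCK_D=1024, grid=(4, 1024); 2560: BLOCK_D=128, grid=(20, 512)
```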