sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -39,6 +39,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -89,12 +90,9 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or (
-                self == ForwardMode.DRAFT_EXTEND_V2
-                if include_draft_extend_v2
-                else False
-            )
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )
 
     def is_decode(self):
@@ -113,22 +111,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.TARGET_VERIFY
 
     def is_draft_extend(self, include_v2: bool = False):
-
-
-
-            )
-        return self == ForwardMode.DRAFT_EXTEND
+        return self == ForwardMode.DRAFT_EXTEND or (
+            include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
+        )
 
     def is_draft_extend_v2(self):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2
 
-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
         )
 
     def is_cuda_graph(self):
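The two `ForwardMode` hunks above replace the conditional-expression form (`x if flag else False`) with the equivalent short-circuit form (`flag and x`). A minimal sketch of the equivalence, using a stand-in enum rather than the real `ForwardMode`:

```python
from enum import IntEnum, auto


class Mode(IntEnum):
    EXTEND = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()


def is_draft_extend(mode: Mode, include_v2: bool = False) -> bool:
    # Flattened form used by the new code: `flag and (x == V)` is False
    # whenever the flag is off, exactly like `(x == V) if flag else False`.
    return mode == Mode.DRAFT_EXTEND or (include_v2 and mode == Mode.DRAFT_EXTEND_V2)


assert not is_draft_extend(Mode.DRAFT_EXTEND_V2)
assert is_draft_extend(Mode.DRAFT_EXTEND_V2, include_v2=True)
assert is_draft_extend(Mode.DRAFT_EXTEND)
```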
@@ -250,6 +247,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None
 
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -316,6 +315,9 @@ class ForwardBatch:
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
     tbo_children: Optional[List[ForwardBatch]] = None
 
+    # For matryoshka embeddings
+    dimensions: Optional[list[int]] = None
+
     @classmethod
     def init_new(
         cls,
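`dimensions` carries per-request matryoshka embedding sizes through `ForwardBatch`. A rough sketch of what such a field is typically used for downstream (truncate the pooled embedding and re-normalize); the helper below is illustrative and is not the pooler code from this package:

```python
import torch


def truncate_matryoshka(embedding: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep the first `dim` components, then re-normalize so cosine similarity
    # still behaves sensibly at the reduced size.
    kept = embedding[..., :dim]
    return kept / kept.norm(dim=-1, keepdim=True).clamp_min(1e-12)


pooled = torch.randn(2, 4096)      # hypothetical pooled embeddings for a 2-request batch
dimensions = [256, 1024]           # per-request sizes, as carried by ForwardBatch.dimensions
outputs = [truncate_matryoshka(e, d) for e, d in zip(pooled, dimensions)]
print([o.shape for o in outputs])  # [torch.Size([256]), torch.Size([1024])]
```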
@@ -357,6 +359,7 @@ class ForwardBatch:
             input_embeds=batch.input_embeds,
             token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
+            dimensions=batch.dimensions,
         )
         device = model_runner.device
 
@@ -572,9 +575,15 @@ class ForwardBatch:
                     device=model_runner.device,
                 )
             else:
-
-
-
+                if mm_input.mrope_position_delta.device.type != model_runner.device:
+                    # transfer mrope_position_delta to device when the first running,
+                    # avoiding successvie host-to-device data transfer
+                    mm_input.mrope_position_delta = (
+                        mm_input.mrope_position_delta.to(
+                            model_runner.device, non_blocking=True
+                        )
+                    )
+                mrope_position_deltas = mm_input.mrope_position_delta.flatten()
                 mrope_positions_list[batch_idx] = (
                     (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
                     .unsqueeze(0)
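The hunk above moves `mrope_position_delta` to the GPU only when it is still on another device, so later steps reuse the already-transferred tensor instead of paying a host-to-device copy every forward. A generic sketch of that caching pattern (class and method names here are illustrative, not from the package):

```python
import torch


class CachedDelta:
    def __init__(self, delta_cpu: torch.Tensor):
        self.delta = delta_cpu  # starts on the host, like mrope_position_delta

    def on(self, device: torch.device) -> torch.Tensor:
        # Transfer once; subsequent calls see a tensor already on `device`.
        if self.delta.device.type != device.type:
            self.delta = self.delta.to(device, non_blocking=True)
        return self.delta


holder = CachedDelta(torch.tensor([[3]]))
target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
flat = holder.on(target).flatten()   # first call may copy; later calls do not
```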
@@ -863,6 +872,10 @@ class ForwardBatch:
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"
 
+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
@@ -917,6 +930,34 @@ class ForwardBatch:
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None
 
+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+
 
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
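`fetch_mha_one_shot_kv_indices` memoizes a flat `kv_indices` array (plus a cumulative `kv_indptr`) so repeated callers within one forward share the result of the Triton kernel. A pure-PyTorch sketch of the layout that kernel produces, on toy data; this is an illustration, not a replacement for `create_flashinfer_kv_indices_triton`:

```python
import torch


def build_kv_indices(req_to_token: torch.Tensor,
                     req_pool_indices: torch.Tensor,
                     seq_lens: torch.Tensor) -> torch.Tensor:
    # For each request: take the first seq_len KV-cache slots of its
    # req_to_token row, then concatenate everything in batch order.
    chunks = [
        req_to_token[pool_idx, :seq_len]
        for pool_idx, seq_len in zip(req_pool_indices.tolist(), seq_lens.tolist())
    ]
    return torch.cat(chunks).to(torch.int32)


req_to_token = torch.arange(12).reshape(3, 4)        # 3 request slots, 4 token slots each
seq_lens = torch.tensor([3, 2])
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)        # start offset of each request
kv_indices = build_kv_indices(req_to_token, torch.tensor([2, 0]), seq_lens)
print(kv_indptr.tolist(), kv_indices.tolist())       # [0, 3, 5] [8, 9, 10, 0, 1]
```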
@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt.configs import
+from sglang.srt.configs import (
+    FalconH1Config,
+    KimiLinearConfig,
+    NemotronHConfig,
+    Qwen3NextConfig,
+)
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -40,6 +45,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
@@ -77,7 +85,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -131,16 +138,10 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
-    monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
     xpu_has_xmx_support,
@@ -355,7 +356,11 @@ class ModelRunner:
 
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
-                compute_initial_expert_location_metadata(
+                compute_initial_expert_location_metadata(
+                    server_args=server_args,
+                    model_config=self.model_config,
+                    moe_ep_rank=self.moe_ep_rank,
+                )
             )
             if self.tp_rank == 0 and get_bool_env_var(
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
@@ -503,121 +508,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend
-            == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -643,37 +533,12 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")
 
-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-
         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
                     self.model_config.hf_config, "quantization_config", None
                 )
-            ) is not None:
+            ) is not None and "weight_block_size" in quantization_config:
                 weight_block_size_n = quantization_config["weight_block_size"][0]
 
                 if self.tp_size % self.moe_ep_size != 0:
@@ -858,8 +723,6 @@ class ModelRunner:
         self.model_config = adjust_config_with_unaligned_cpu_tp(
             self.model_config, self.load_config, self.tp_size
         )
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()
 
         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
             if self.tp_rank == 0:
@@ -878,7 +741,6 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
-        monkey_patch_isinstance_for_vllm_base_layer()
 
         with self.memory_saver_adapter.region(
             GPU_MEMORY_TYPE_WEIGHTS,
@@ -890,7 +752,6 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
-        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         get_offloader().post_init()
 
@@ -938,6 +799,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
 
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
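When `debug_tensor_dump_output_folder` is set, the runner now registers dump hooks right after the weights load. The real helper lives in the new `sglang/srt/debug_utils/tensor_dump_forward_hook.py`; the sketch below only shows the underlying PyTorch mechanism (`register_forward_hook`) with made-up file naming, not that module's actual behavior:

```python
import os

import torch


def register_dump_hooks(model: torch.nn.Module, dump_dir: str, tp_rank: int = 0) -> None:
    # Save each module's tensor output to disk after its forward pass.
    os.makedirs(dump_dir, exist_ok=True)

    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                safe = name.replace(".", "_") or "root"
                torch.save(output.detach().cpu(), os.path.join(dump_dir, f"{safe}_tp{tp_rank}.pt"))
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
register_dump_hooks(model, "/tmp/tensor_dumps")
model(torch.randn(1, 4))   # writes one .pt file per module output
```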
@@ -1493,9 +1363,16 @@ class ModelRunner:
             return config
         return None
 
+    @property
+    def kimi_linear_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, KimiLinearConfig):
+            return config
+        return None
+
     @property
     def mambaish_config(self):
-        return self.mamba2_config or self.hybrid_gdn_config
+        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
 
     def set_num_token_hybrid(self):
         if (
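`mambaish_config` now also picks up `KimiLinearConfig` by or-chaining optional properties: each property returns the typed config or `None`, and the chain yields the first match. A stripped-down sketch of that pattern (stand-in config classes, not the real ones):

```python
from typing import Optional


class Mamba2Config: ...
class KimiLinearConfig: ...


class Runner:
    def __init__(self, hf_config):
        self.hf_config = hf_config

    @property
    def mamba2_config(self) -> Optional[Mamba2Config]:
        return self.hf_config if isinstance(self.hf_config, Mamba2Config) else None

    @property
    def kimi_linear_config(self) -> Optional[KimiLinearConfig]:
        return self.hf_config if isinstance(self.hf_config, KimiLinearConfig) else None

    @property
    def mambaish_config(self):
        # First non-None config wins; None if the model is neither.
        return self.mamba2_config or self.kimi_linear_config


assert isinstance(Runner(KimiLinearConfig()).mambaish_config, KimiLinearConfig)
assert Runner(object()).mambaish_config is None
```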
@@ -1806,9 +1683,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
@@ -1824,7 +1703,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
                 index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
             )
-        elif self.use_mla_backend:
+        elif self.use_mla_backend and not self.mambaish_config:
             assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1868,6 +1747,12 @@ class ModelRunner:
                 device=self.device,
             )
         elif config := self.mambaish_config:
+            extra_args = {}
+            if self.use_mla_backend:
+                extra_args = {
+                    "kv_lora_rank": self.model_config.kv_lora_rank,
+                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
+                }
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1883,6 +1768,8 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
+                use_mla=self.use_mla_backend,
+                **extra_args,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -1898,6 +1785,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),
@@ -1966,12 +1854,18 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
 
-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
@@ -1985,10 +1879,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
@@ -2002,7 +1898,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
 
         (
@@ -2011,9 +1908,12 @@ class ModelRunner:
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
 
-    def _get_attention_backend_from_str(
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)
 
@@ -2111,6 +2011,9 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
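Under PD multiplexing, the runner keeps one decode attention backend per SM group and switches between them by index (`update_decode_attn_backend`). A minimal sketch of that bookkeeping, with a dummy backend factory standing in for `_get_attention_backend`:

```python
class DummyBackend:
    def __init__(self, tag: str):
        self.tag = tag


class PdmuxRunner:
    def __init__(self, sm_group_num: int):
        # One backend per SM group, mirroring decode_attn_backend_group.
        self.decode_attn_backend_group = [
            DummyBackend(f"group-{i}") for i in range(sm_group_num)
        ]
        self.decode_attn_backend = self.decode_attn_backend_group[0]

    def update_decode_attn_backend(self, stream_idx: int) -> None:
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]


runner = PdmuxRunner(sm_group_num=3)
runner.update_decode_attn_backend(2)
print(runner.decode_attn_backend.tag)  # group-2
```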
@@ -2118,7 +2021,11 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -2256,18 +2163,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
@@ -75,9 +75,13 @@ class NPUGraphRunner(CudaGraphRunner):
 
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
-
-
-
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+            else:
+                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                    self.bs - self.raw_bs
+                )
             thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
             thread.start()
             self.graphs[self.bs].replay()
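On NPU graph replay, `seq_lens` has to be padded out to the captured batch size, and during target-verify each request additionally accounts for its draft tokens. Toy arithmetic for both cases (numbers are made up):

```python
import torch

graph_bs, raw_bs, num_tokens_per_bs = 8, 5, 4
seq_lens = torch.tensor([13, 7, 22, 9, 30])

# Target-verify replay: every request grows by the draft tokens being verified.
verify_lens = (seq_lens + num_tokens_per_bs).tolist() + [0] * (graph_bs - raw_bs)
# Regular replay: pad the real lengths with zeros up to the captured batch size.
decode_lens = seq_lens.tolist() + [0] * (graph_bs - raw_bs)

print(verify_lens)  # [17, 11, 26, 13, 34, 0, 0, 0]
print(decode_lens)  # [13, 7, 22, 9, 30, 0, 0, 0]
```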
@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -250,6 +249,9 @@ class PiecewiseCudaGraphRunner:
             lora_ids=None,
         )
 
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
         with set_forward_context(forward_batch, self.attention_layers):
             _ = self.model_runner.model.forward(
                 forward_batch.input_ids,
@@ -262,9 +264,14 @@ class PiecewiseCudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         num_tokens = len(forward_batch.input_ids)
-        # TODO(yuwei): support return logprob
+        # TODO(yuwei): support return input_ids' logprob
         if forward_batch.return_logprob:
-
+            for start_len, seq_len in zip(
+                forward_batch.extend_logprob_start_lens_cpu,
+                forward_batch.extend_seq_lens_cpu,
+            ):
+                if start_len is not None and start_len < seq_len:
+                    return False
         if num_tokens <= self.max_num_tokens:
             return True
         return False
@@ -273,10 +280,10 @@ class PiecewiseCudaGraphRunner:
         # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(
-            self.model_runner.
-
-
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -304,9 +311,10 @@ class PiecewiseCudaGraphRunner:
 
             # Save gemlite cache after each capture
            save_gemlite_cache()
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def capture_one_batch_size(self, num_tokens: int):
-        stream = self.stream
         bs = 1
 
         # Graph inputs
@@ -370,9 +378,6 @@ class PiecewiseCudaGraphRunner:
         if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
 
-        # # Attention backend
-        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
-
         # Run and capture
         def run_once():
             # Clean intermediate result cache for DP attention
@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner:
             out_cache_loc=out_cache_loc,
             seq_lens_sum=forward_batch.seq_lens_sum,
             encoder_lens=forward_batch.encoder_lens,
-            return_logprob=
+            return_logprob=False,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_prefix_lens=forward_batch.extend_prefix_lens,
             extend_start_loc=forward_batch.extend_start_loc,
@@ -474,6 +479,9 @@ class PiecewiseCudaGraphRunner:
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -499,6 +507,8 @@ class PiecewiseCudaGraphRunner:
             raise NotImplementedError(
                 "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
             )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None