vllm-ascend 0.10.0rc1__cp310-cp310-manylinux_2_24_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vllm_ascend/__init__.py +27 -0
- vllm_ascend/_build_info.py +3 -0
- vllm_ascend/_version.py +21 -0
- vllm_ascend/ascend_config.py +183 -0
- vllm_ascend/ascend_forward_context.py +114 -0
- vllm_ascend/attention/__init__.py +0 -0
- vllm_ascend/attention/attention_mask.py +104 -0
- vllm_ascend/attention/attention_v1.py +477 -0
- vllm_ascend/attention/attention_v1_torchair.py +496 -0
- vllm_ascend/attention/mla_v1.py +1279 -0
- vllm_ascend/compilation/__init__.py +0 -0
- vllm_ascend/compilation/piecewise_backend.py +225 -0
- vllm_ascend/core/__init__.py +0 -0
- vllm_ascend/core/schedule_config.py +74 -0
- vllm_ascend/core/scheduler.py +487 -0
- vllm_ascend/device_allocator/__init__.py +0 -0
- vllm_ascend/device_allocator/camem.py +278 -0
- vllm_ascend/distributed/__init__.py +24 -0
- vllm_ascend/distributed/communication_op.py +25 -0
- vllm_ascend/distributed/communicator.py +96 -0
- vllm_ascend/distributed/device_communicators/__init__.py +0 -0
- vllm_ascend/distributed/device_communicators/pyhccl.py +165 -0
- vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py +253 -0
- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +894 -0
- vllm_ascend/distributed/parallel_state.py +48 -0
- vllm_ascend/distributed/tensor_parallel.py +248 -0
- vllm_ascend/envs.py +175 -0
- vllm_ascend/libvllm_ascend_kernels.so +0 -0
- vllm_ascend/lora/__init__.py +0 -0
- vllm_ascend/lora/punica_wrapper/__init__.py +0 -0
- vllm_ascend/lora/punica_wrapper/lora_ops.py +112 -0
- vllm_ascend/lora/punica_wrapper/punica_npu.py +364 -0
- vllm_ascend/models/__init__.py +61 -0
- vllm_ascend/models/deepseek_dbo.py +1046 -0
- vllm_ascend/models/deepseek_mtp.py +218 -0
- vllm_ascend/models/deepseek_v2.py +990 -0
- vllm_ascend/models/deepseek_v3.py +27 -0
- vllm_ascend/models/pangu_moe.py +1117 -0
- vllm_ascend/models/qwen2_5_vl.py +499 -0
- vllm_ascend/models/qwen2_5_vl_without_padding.py +377 -0
- vllm_ascend/models/qwen2_vl.py +352 -0
- vllm_ascend/models/qwen3.py +156 -0
- vllm_ascend/models/qwen3_moe.py +388 -0
- vllm_ascend/multistream/__init__.py +0 -0
- vllm_ascend/multistream/base.py +29 -0
- vllm_ascend/multistream/context.py +67 -0
- vllm_ascend/multistream/decorator.py +22 -0
- vllm_ascend/multistream/layers.py +61 -0
- vllm_ascend/multistream/metadata.py +182 -0
- vllm_ascend/multistream/ms_split.py +247 -0
- vllm_ascend/ops/__init__.py +49 -0
- vllm_ascend/ops/activation.py +42 -0
- vllm_ascend/ops/attention.py +309 -0
- vllm_ascend/ops/cache.py +35 -0
- vllm_ascend/ops/comm_utils.py +62 -0
- vllm_ascend/ops/common_fused_moe.py +115 -0
- vllm_ascend/ops/expert_load_balancer.py +99 -0
- vllm_ascend/ops/fused_moe.py +1557 -0
- vllm_ascend/ops/layernorm.py +86 -0
- vllm_ascend/ops/moe_dispatcher/__init__.py +0 -0
- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +453 -0
- vllm_ascend/ops/rotary_embedding.py +292 -0
- vllm_ascend/ops/sequence_parallel.py +120 -0
- vllm_ascend/ops/vocab_parallel_embedding.py +74 -0
- vllm_ascend/patch/__init__.py +104 -0
- vllm_ascend/patch/platform/__init__.py +25 -0
- vllm_ascend/patch/platform/patch_0_10_0/__init__.py +16 -0
- vllm_ascend/patch/platform/patch_common/__init__.py +18 -0
- vllm_ascend/patch/platform/patch_common/patch_distributed.py +115 -0
- vllm_ascend/patch/platform/patch_main/__init__.py +16 -0
- vllm_ascend/patch/worker/__init__.py +26 -0
- vllm_ascend/patch/worker/patch_0_10_0/__init__.py +18 -0
- vllm_ascend/patch/worker/patch_0_10_0/patch_sampler_gather_logprobs.py +87 -0
- vllm_ascend/patch/worker/patch_common/__init__.py +20 -0
- vllm_ascend/patch/worker/patch_common/patch_distributed.py +49 -0
- vllm_ascend/patch/worker/patch_common/patch_linear.py +145 -0
- vllm_ascend/patch/worker/patch_common/patch_minicpm.py +36 -0
- vllm_ascend/patch/worker/patch_main/__init__.py +16 -0
- vllm_ascend/platform.py +288 -0
- vllm_ascend/quantization/__init__.py +0 -0
- vllm_ascend/quantization/func_wrapper.py +184 -0
- vllm_ascend/quantization/quant_config.py +354 -0
- vllm_ascend/quantization/quantizer.py +311 -0
- vllm_ascend/quantization/w4a8_dynamic.py +396 -0
- vllm_ascend/quantization/w8a8.py +767 -0
- vllm_ascend/quantization/w8a8_dynamic.py +1033 -0
- vllm_ascend/sample/__init__.py +0 -0
- vllm_ascend/sample/rejection_sampler.py +453 -0
- vllm_ascend/sample/sampler.py +65 -0
- vllm_ascend/torchair/__init__.py +0 -0
- vllm_ascend/torchair/torchair_model_runner.py +29 -0
- vllm_ascend/torchair/torchair_worker.py +61 -0
- vllm_ascend/torchair/utils.py +98 -0
- vllm_ascend/utils.py +507 -0
- vllm_ascend/vllm_ascend_C.cpython-310-aarch64-linux-gnu.so +0 -0
- vllm_ascend/worker/__init__.py +0 -0
- vllm_ascend/worker/eagle_proposer_v1.py +384 -0
- vllm_ascend/worker/model_runner_v1.py +2791 -0
- vllm_ascend/worker/mtp_proposer_v1.py +400 -0
- vllm_ascend/worker/npu_input_batch.py +758 -0
- vllm_ascend/worker/worker_v1.py +355 -0
- vllm_ascend-0.10.0rc1.dist-info/LICENSE +201 -0
- vllm_ascend-0.10.0rc1.dist-info/METADATA +130 -0
- vllm_ascend-0.10.0rc1.dist-info/RECORD +107 -0
- vllm_ascend-0.10.0rc1.dist-info/WHEEL +5 -0
- vllm_ascend-0.10.0rc1.dist-info/entry_points.txt +5 -0
- vllm_ascend-0.10.0rc1.dist-info/top_level.txt +1 -0
vllm_ascend/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# This file is a part of the vllm-ascend project.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def register():
    """Register the NPU platform.

    Returns:
        The fully-qualified dotted path of the NPU platform class.
    """
    platform_cls_path = "vllm_ascend.platform.NPUPlatform"
    return platform_cls_path
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_model():
    """Delegate to ``vllm_ascend.models.register_model``.

    The import is done at call time, so importing this package does not
    import the ``models`` subpackage.
    """
    from .models import register_model as _register_models
    _register_models()
|
vllm_ascend/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
# Public names exported by this generated module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Hard-coded False: at runtime only the `else` branch below executes, so the
# typing imports are never performed outside of static analysis.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Declared types of the version attributes assigned below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The release string, and the same version split into components.
__version__ = version = '0.10.0rc1'
__version_tuple__ = version_tuple = (0, 10, 0, 'rc1')
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
3
|
+
# This file is a part of the vllm-ascend project.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from vllm.logger import logger
|
|
19
|
+
|
|
20
|
+
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _check_torchair_supported(model_type: str):
|
|
24
|
+
for supported_model in TORCHAIR_MODEL_LIST:
|
|
25
|
+
if supported_model in model_type.lower():
|
|
26
|
+
return True
|
|
27
|
+
return False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AscendConfig:
    """
    Configuration Object for additional_config from vllm.configs.

    Attributes are read from ``vllm_config.additional_config`` (treated as an
    empty dict when it is ``None``):

    * ``torchair_graph_config``: ``TorchairGraphConfig`` built from the
      ``"torchair_graph_config"`` entry (default ``{}``).
    * ``ascend_scheduler_config``: ``AscendSchedulerConfig`` built from the
      ``"ascend_scheduler_config"`` entry (default ``{}``).
    * ``expert_map_path``: the ``"expert_map_path"`` entry (default ``None``).
    * ``chunked_prefill_for_mla``: the ``"chunked_prefill_for_mla"`` entry
      (default ``False``).
    """

    def __init__(self, vllm_config):
        extra = vllm_config.additional_config
        if extra is None:
            extra = {}

        # Sub-configs are validated by their own constructors, which may raise.
        self.torchair_graph_config = TorchairGraphConfig(
            extra.get("torchair_graph_config", {}))
        self.ascend_scheduler_config = AscendSchedulerConfig(
            extra.get("ascend_scheduler_config", {}))

        self.expert_map_path = extra.get("expert_map_path", None)
        self.chunked_prefill_for_mla = extra.get("chunked_prefill_for_mla",
                                                 False)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TorchairGraphConfig:
    """
    Configuration Object for torchair_graph_config from additional_config

    Raises:
        TypeError: If ``graph_batch_sizes`` is not a list.
        ValueError: If ``graph_batch_sizes_init`` is set while
            ``graph_batch_sizes`` is non-empty.
        RuntimeError: If a graph-only option is set while ``enabled`` is false.
    """

    def __init__(self, torchair_graph_config):
        cfg = torchair_graph_config
        self.enabled = cfg.get("enabled", False)
        self.use_cached_graph = cfg.get("use_cached_graph", False)
        self.graph_batch_sizes = cfg.get("graph_batch_sizes", [])
        self.graph_batch_sizes_init = cfg.get("graph_batch_sizes_init", False)
        self.enable_multistream_mla = cfg.get("enable_multistream_mla", False)
        self.enable_multistream_moe = cfg.get("enable_multistream_moe", False)
        self.enable_view_optimize = cfg.get("enable_view_optimize", True)
        self.enable_kv_nz = cfg.get("enable_kv_nz", False)

        if not isinstance(self.graph_batch_sizes, list):
            raise TypeError("graph_batch_sizes must be list[int]")
        # Explicit sizes and auto-initialised sizes are mutually exclusive.
        if self.graph_batch_sizes_init and self.graph_batch_sizes:
            raise ValueError(
                "graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
            )
        if self.enabled:
            return

        # With graph mode off, none of these options may be switched on.
        # They are checked in this order, so the first offender is reported.
        graph_only_options = (
            "use_cached_graph",
            "graph_batch_sizes",
            "graph_batch_sizes_init",
            "enable_multistream_mla",
            "enable_multistream_moe",
            "enable_kv_nz",
        )
        for option in graph_only_options:
            if getattr(self, option):
                raise RuntimeError(
                    f"{option} is valid only when Torchair graph mode is enabled"
                )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AscendSchedulerConfig:
    """
    Configuration Object for ascend_scheduler_config from additional_config

    ``enabled`` defaults to ``False``. Every other key is copied onto the
    instance as an attribute, unless that name already resolves on it.
    """

    def __init__(self, ascend_scheduler_config: dict):
        self.enabled = ascend_scheduler_config.get("enabled", False)
        # Ascend scheduler is based on vllm v0 scheduler, so we should support
        # all vllm v0 scheduler configs as well.
        for name, value in ascend_scheduler_config.items():
            if hasattr(self, name):
                # Never overwrite existing attributes (e.g. `enabled`).
                continue
            setattr(self, name, value)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
_ASCEND_CONFIG: Optional[AscendConfig] = None


def init_ascend_config(vllm_config):
    """Create the module-level ``AscendConfig``, or reuse the existing one.

    The cached instance is returned as-is unless
    ``vllm_config.additional_config["refresh"]`` is truthy. In that case a new
    ``AscendConfig`` is built and replaces the cached one.

    Returns:
        The (possibly newly created) ``AscendConfig``.
    """
    global _ASCEND_CONFIG
    extra = vllm_config.additional_config
    if extra is None:
        extra = {}
    refresh = extra.get("refresh", False)
    if _ASCEND_CONFIG is None or refresh:
        _ASCEND_CONFIG = AscendConfig(vllm_config)
    return _ASCEND_CONFIG
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def clear_ascend_config():
    """Drop the cached ``AscendConfig``.

    The next ``init_ascend_config`` call builds a fresh one, and
    ``get_ascend_config`` raises until then.
    """
    global _ASCEND_CONFIG
    _ASCEND_CONFIG = None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def get_ascend_config():
    """Return the ``AscendConfig`` cached by ``init_ascend_config``.

    Raises:
        RuntimeError: If no config is cached (never initialised, or cleared).
    """
    config = _ASCEND_CONFIG
    if config is None:
        raise RuntimeError(
            "Ascend config is not initialized. Please call init_ascend_config first."
        )
    return config
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def check_ascend_config(vllm_config, enforce_eager):
    """Validate the cached Ascend config against the requested execution mode.

    Args:
        vllm_config: The vLLM config. When ``model_config`` is set, its
            ``hf_config.model_type`` is inspected.
        enforce_eager: Whether eager mode was requested.

    Raises:
        RuntimeError: If Torchair graph mode is enabled together with eager
            mode, or if the Ascend config was not initialised.
        NotImplementedError: If Torchair graph mode is enabled for a model type
            not matched by ``TORCHAIR_MODEL_LIST``, or if ACL graph mode (graph
            mode without Torchair) is used with a model type containing
            "deepseek".
    """
    ascend_config = get_ascend_config()

    # for eager mode
    if enforce_eager:
        # torchair_graph cannot be enabled with eager mode.
        if ascend_config.torchair_graph_config.enabled:
            raise RuntimeError(
                "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
            )
    # for graph mode
    else:
        # torchair_graph case
        if ascend_config.torchair_graph_config.enabled:
            # torchair_graph is supported only for the model types listed in
            # TORCHAIR_MODEL_LIST (see _check_torchair_supported).
            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if not _check_torchair_supported(model_type):
                    raise NotImplementedError(
                        "Torchair graph mode only works with following model types: "
                        f"{TORCHAIR_MODEL_LIST}.")
        # aclgraph case
        else:
            # aclgraph doesn't work with deepseek model and only qwen model is well tested.
            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if "deepseek" in model_type:
                    raise NotImplementedError(
                        "ACL Graph does not support deepseek. Please "
                        "try torchair graph mode to serve deepseek models on vllm-ascend."
                        " Or set `enforce_eager=True` to use eager mode.")
                if "qwen" not in model_type:
                    logger.warning(
                        "ACL Graph is currently experimental. Please "
                        "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
                        " if you encounter any Error")
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from vllm.config import VllmConfig
|
|
8
|
+
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
|
9
|
+
from vllm.forward_context import get_forward_context, set_forward_context
|
|
10
|
+
|
|
11
|
+
import vllm_ascend.envs as envs
|
|
12
|
+
from vllm_ascend.platform import NPUPlatform
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FusedMoEState(Enum):
    """Communication strategy used by the fused MoE layers in a forward pass.

    The chosen member is stored on the forward context as
    ``fused_moe_state``. The selection rules live in ``_get_fused_moe_state``.
    """

    AllGather = 0  # ep_size == 1, decode-only batch
    All2All = 1  # ep_size > 1, prefill or ep_size < 16
    MC2 = 2  # ep_size >= 16, decode-only batch
    AllGatherEP = 3  # fused all-gather EP path (DeepSeek V3/R1 only)
    NaiveMulticast = 4  # ep_size == 1, batch with prefill
    All2AllSeq = 5  # All2All variant gated by VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# TODO(zzzzwwjj): add soc_version to choose branch
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                         is_deepseek_v3_r1: bool):
    """Pick the fused-MoE communication strategy for one forward pass.

    Args:
        ep_size: Expert-parallel world size (1 when expert parallelism is off).
        with_prefill: Whether the batch contains prefill tokens.
        is_deepseek_v3_r1: Whether the model was detected as DeepSeek V3/R1.

    Returns:
        A ``FusedMoEState`` member.
    """
    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
    # only supports deepseek v3/r1
    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
            and is_deepseek_v3_r1):
        return FusedMoEState.AllGatherEP

    if ep_size == 1:
        return (FusedMoEState.NaiveMulticast
                if with_prefill else FusedMoEState.AllGather)

    # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
    # MC2 is therefore chosen only for decode-only batches on >= 16 EP ranks.
    use_mc2 = ep_size >= 16 and not with_prefill
    if envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
        # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
        return FusedMoEState.MC2 if use_mc2 else FusedMoEState.All2AllSeq
    return FusedMoEState.MC2 if use_mc2 else FusedMoEState.All2All
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    with_prefill: bool = True,
    in_profile_run: bool = False,
    num_actual_tokens: Optional[int] = None,
):
    """Install vLLM's forward context and add Ascend-specific fields to it.

    Wraps ``set_forward_context``. Inside it, the following attributes are set
    on the forward context: ``with_prefill``, ``fused_moe_state``,
    ``in_profile_run``, ``capturing`` (always ``False`` here), and
    ``max_tokens_across_dp``. When a token count is known,
    ``padded_num_tokens`` and ``mc2_mask`` are set as well.

    Args:
        attn_metadata: Attention metadata forwarded to ``set_forward_context``.
            When ``num_tokens`` is None and this is not None, its
            ``num_actual_tokens`` is used as the token count.
        vllm_config: The vLLM config.
        virtual_engine: Forwarded to ``set_forward_context``.
        num_tokens: Token count of this step (also forwarded).
        num_tokens_across_dp: Forwarded to ``set_forward_context``.
        with_prefill: Whether the batch contains prefill. It is stored on the
            context and drives the fused-MoE state selection.
        in_profile_run: Stored on the forward context.
        num_actual_tokens: Number of leading ``mc2_mask`` entries set to True.
            Defaults to the token count.
    """
    with set_forward_context(attn_metadata,
                             vllm_config,
                             virtual_engine=virtual_engine,
                             num_tokens=num_tokens,
                             num_tokens_across_dp=num_tokens_across_dp):
        forward_context = get_forward_context()
        forward_context.with_prefill = with_prefill
        ep_size = (get_ep_group().world_size if
                   vllm_config.parallel_config.enable_expert_parallel else 1)

        # Heuristic: a model config with exactly 256 routed experts is treated
        # as DeepSeek V3/R1.
        is_deepseek_v3_r1 = hasattr(
            vllm_config.model_config.hf_config, 'n_routed_experts'
        ) and vllm_config.model_config.hf_config.n_routed_experts == 256
        fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
                                               is_deepseek_v3_r1)
        forward_context.fused_moe_state = fused_moe_state
        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
            )
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            tp_world_size = get_tp_group().world_size
            # NOTE: token num which need to pad to when mc2, i.e.
            # max_tokens_across_dp rounded up to a multiple of the TP size.
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size

            # True for the first num_actual_tokens entries, False for the
            # remaining padded positions.
            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
                                   dtype=torch.bool,
                                   device=NPUPlatform.device_type)
            mc2_mask[:num_actual_tokens] = True
            forward_context.mc2_mask = mc2_mask

        # No cleanup is needed here; exceptions raised in the caller's `with`
        # body propagate through this yield and out of set_forward_context.
        yield
|
|
File without changes
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _generate_attn_mask(max_seq_len, dtype):
|
|
19
|
+
# Construct lower triangle matrix.
|
|
20
|
+
mask_flag = torch.tril(
|
|
21
|
+
torch.ones((max_seq_len, max_seq_len),
|
|
22
|
+
dtype=torch.bool)).view(max_seq_len, max_seq_len)
|
|
23
|
+
# Create upper triangle matrix used to mark mask positions.
|
|
24
|
+
mask_flag = ~mask_flag
|
|
25
|
+
# Currently for fp16 dtype, the mask value should be set to -inf.
|
|
26
|
+
# TODO: Eliminate this part in the future.
|
|
27
|
+
if dtype == torch.float16:
|
|
28
|
+
mask_value = torch.finfo(torch.float32).min
|
|
29
|
+
else:
|
|
30
|
+
mask_value = 1
|
|
31
|
+
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
|
32
|
+
mask_flag, mask_value).to(dtype)
|
|
33
|
+
return attn_mask
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AttentionMaskBuilder:
|
|
37
|
+
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
max_seq_len: int,
|
|
41
|
+
dtype: torch.dtype,
|
|
42
|
+
):
|
|
43
|
+
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
|
44
|
+
|
|
45
|
+
self._seq_len_cached = attn_mask.shape[0]
|
|
46
|
+
self.attn_mask_cache = attn_mask
|
|
47
|
+
self.splitfuse_mask_value = -10000
|
|
48
|
+
|
|
49
|
+
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
|
50
|
+
device: torch.device):
|
|
51
|
+
self._update_attn_cache(max_seq_len, dtype, device)
|
|
52
|
+
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
|
53
|
+
|
|
54
|
+
def get_splitfuse_attn_mask(
|
|
55
|
+
self,
|
|
56
|
+
seq_lens,
|
|
57
|
+
query_lens,
|
|
58
|
+
position,
|
|
59
|
+
dtype,
|
|
60
|
+
device,
|
|
61
|
+
) -> torch.Tensor:
|
|
62
|
+
max_seq_len = max(seq_lens, default=0)
|
|
63
|
+
if max_seq_len <= self._seq_len_cached:
|
|
64
|
+
self._update_attn_cache(max_seq_len, dtype, device)
|
|
65
|
+
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
|
66
|
+
# is not the same. Fix this in the future when kernel is ready.
|
|
67
|
+
if self.attn_mask_cache.numel(
|
|
68
|
+
) > 1 and self.attn_mask_cache[0][1] > 0:
|
|
69
|
+
attn_mask = self.get_attn_mask( # type: ignore
|
|
70
|
+
max_seq_len, dtype, device)
|
|
71
|
+
# Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
|
|
72
|
+
attn_mask = attn_mask * -10000
|
|
73
|
+
else:
|
|
74
|
+
attn_mask = self.attn_mask_cache
|
|
75
|
+
return torch.index_select(attn_mask, dim=0,
|
|
76
|
+
index=position)[:, :max_seq_len]
|
|
77
|
+
total_q_len = sum(query_lens)
|
|
78
|
+
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
|
79
|
+
dtype=dtype,
|
|
80
|
+
device="cpu")
|
|
81
|
+
current_row = 0
|
|
82
|
+
for i in range(len(query_lens)):
|
|
83
|
+
seq_len = seq_lens[i]
|
|
84
|
+
q_len = query_lens[i]
|
|
85
|
+
context_len = seq_len - q_len
|
|
86
|
+
|
|
87
|
+
assert context_len >= 0
|
|
88
|
+
attn_mask[current_row:current_row + q_len,
|
|
89
|
+
context_len:] = self.splitfuse_mask_value
|
|
90
|
+
right_tensor = attn_mask[current_row:current_row + q_len,
|
|
91
|
+
context_len:seq_len]
|
|
92
|
+
right_tensor.masked_fill_(
|
|
93
|
+
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
|
94
|
+
current_row += q_len
|
|
95
|
+
|
|
96
|
+
return attn_mask.to(device, non_blocking=True)
|
|
97
|
+
|
|
98
|
+
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
|
99
|
+
device: torch.device):
|
|
100
|
+
if seqlen > self._seq_len_cached:
|
|
101
|
+
self._seq_len_cached = seqlen
|
|
102
|
+
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
|
103
|
+
if self.attn_mask_cache.device != device:
|
|
104
|
+
self.attn_mask_cache = self.attn_mask_cache.to(device)
|