vllm-ascend 0.10.0rc1 (vllm_ascend-0.10.0rc1-cp310-cp310-manylinux_2_24_aarch64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. vllm_ascend/__init__.py +27 -0
  2. vllm_ascend/_build_info.py +3 -0
  3. vllm_ascend/_version.py +21 -0
  4. vllm_ascend/ascend_config.py +183 -0
  5. vllm_ascend/ascend_forward_context.py +114 -0
  6. vllm_ascend/attention/__init__.py +0 -0
  7. vllm_ascend/attention/attention_mask.py +104 -0
  8. vllm_ascend/attention/attention_v1.py +477 -0
  9. vllm_ascend/attention/attention_v1_torchair.py +496 -0
  10. vllm_ascend/attention/mla_v1.py +1279 -0
  11. vllm_ascend/compilation/__init__.py +0 -0
  12. vllm_ascend/compilation/piecewise_backend.py +225 -0
  13. vllm_ascend/core/__init__.py +0 -0
  14. vllm_ascend/core/schedule_config.py +74 -0
  15. vllm_ascend/core/scheduler.py +487 -0
  16. vllm_ascend/device_allocator/__init__.py +0 -0
  17. vllm_ascend/device_allocator/camem.py +278 -0
  18. vllm_ascend/distributed/__init__.py +24 -0
  19. vllm_ascend/distributed/communication_op.py +25 -0
  20. vllm_ascend/distributed/communicator.py +96 -0
  21. vllm_ascend/distributed/device_communicators/__init__.py +0 -0
  22. vllm_ascend/distributed/device_communicators/pyhccl.py +165 -0
  23. vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py +253 -0
  24. vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +894 -0
  25. vllm_ascend/distributed/parallel_state.py +48 -0
  26. vllm_ascend/distributed/tensor_parallel.py +248 -0
  27. vllm_ascend/envs.py +175 -0
  28. vllm_ascend/libvllm_ascend_kernels.so +0 -0
  29. vllm_ascend/lora/__init__.py +0 -0
  30. vllm_ascend/lora/punica_wrapper/__init__.py +0 -0
  31. vllm_ascend/lora/punica_wrapper/lora_ops.py +112 -0
  32. vllm_ascend/lora/punica_wrapper/punica_npu.py +364 -0
  33. vllm_ascend/models/__init__.py +61 -0
  34. vllm_ascend/models/deepseek_dbo.py +1046 -0
  35. vllm_ascend/models/deepseek_mtp.py +218 -0
  36. vllm_ascend/models/deepseek_v2.py +990 -0
  37. vllm_ascend/models/deepseek_v3.py +27 -0
  38. vllm_ascend/models/pangu_moe.py +1117 -0
  39. vllm_ascend/models/qwen2_5_vl.py +499 -0
  40. vllm_ascend/models/qwen2_5_vl_without_padding.py +377 -0
  41. vllm_ascend/models/qwen2_vl.py +352 -0
  42. vllm_ascend/models/qwen3.py +156 -0
  43. vllm_ascend/models/qwen3_moe.py +388 -0
  44. vllm_ascend/multistream/__init__.py +0 -0
  45. vllm_ascend/multistream/base.py +29 -0
  46. vllm_ascend/multistream/context.py +67 -0
  47. vllm_ascend/multistream/decorator.py +22 -0
  48. vllm_ascend/multistream/layers.py +61 -0
  49. vllm_ascend/multistream/metadata.py +182 -0
  50. vllm_ascend/multistream/ms_split.py +247 -0
  51. vllm_ascend/ops/__init__.py +49 -0
  52. vllm_ascend/ops/activation.py +42 -0
  53. vllm_ascend/ops/attention.py +309 -0
  54. vllm_ascend/ops/cache.py +35 -0
  55. vllm_ascend/ops/comm_utils.py +62 -0
  56. vllm_ascend/ops/common_fused_moe.py +115 -0
  57. vllm_ascend/ops/expert_load_balancer.py +99 -0
  58. vllm_ascend/ops/fused_moe.py +1557 -0
  59. vllm_ascend/ops/layernorm.py +86 -0
  60. vllm_ascend/ops/moe_dispatcher/__init__.py +0 -0
  61. vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +453 -0
  62. vllm_ascend/ops/rotary_embedding.py +292 -0
  63. vllm_ascend/ops/sequence_parallel.py +120 -0
  64. vllm_ascend/ops/vocab_parallel_embedding.py +74 -0
  65. vllm_ascend/patch/__init__.py +104 -0
  66. vllm_ascend/patch/platform/__init__.py +25 -0
  67. vllm_ascend/patch/platform/patch_0_10_0/__init__.py +16 -0
  68. vllm_ascend/patch/platform/patch_common/__init__.py +18 -0
  69. vllm_ascend/patch/platform/patch_common/patch_distributed.py +115 -0
  70. vllm_ascend/patch/platform/patch_main/__init__.py +16 -0
  71. vllm_ascend/patch/worker/__init__.py +26 -0
  72. vllm_ascend/patch/worker/patch_0_10_0/__init__.py +18 -0
  73. vllm_ascend/patch/worker/patch_0_10_0/patch_sampler_gather_logprobs.py +87 -0
  74. vllm_ascend/patch/worker/patch_common/__init__.py +20 -0
  75. vllm_ascend/patch/worker/patch_common/patch_distributed.py +49 -0
  76. vllm_ascend/patch/worker/patch_common/patch_linear.py +145 -0
  77. vllm_ascend/patch/worker/patch_common/patch_minicpm.py +36 -0
  78. vllm_ascend/patch/worker/patch_main/__init__.py +16 -0
  79. vllm_ascend/platform.py +288 -0
  80. vllm_ascend/quantization/__init__.py +0 -0
  81. vllm_ascend/quantization/func_wrapper.py +184 -0
  82. vllm_ascend/quantization/quant_config.py +354 -0
  83. vllm_ascend/quantization/quantizer.py +311 -0
  84. vllm_ascend/quantization/w4a8_dynamic.py +396 -0
  85. vllm_ascend/quantization/w8a8.py +767 -0
  86. vllm_ascend/quantization/w8a8_dynamic.py +1033 -0
  87. vllm_ascend/sample/__init__.py +0 -0
  88. vllm_ascend/sample/rejection_sampler.py +453 -0
  89. vllm_ascend/sample/sampler.py +65 -0
  90. vllm_ascend/torchair/__init__.py +0 -0
  91. vllm_ascend/torchair/torchair_model_runner.py +29 -0
  92. vllm_ascend/torchair/torchair_worker.py +61 -0
  93. vllm_ascend/torchair/utils.py +98 -0
  94. vllm_ascend/utils.py +507 -0
  95. vllm_ascend/vllm_ascend_C.cpython-310-aarch64-linux-gnu.so +0 -0
  96. vllm_ascend/worker/__init__.py +0 -0
  97. vllm_ascend/worker/eagle_proposer_v1.py +384 -0
  98. vllm_ascend/worker/model_runner_v1.py +2791 -0
  99. vllm_ascend/worker/mtp_proposer_v1.py +400 -0
  100. vllm_ascend/worker/npu_input_batch.py +758 -0
  101. vllm_ascend/worker/worker_v1.py +355 -0
  102. vllm_ascend-0.10.0rc1.dist-info/LICENSE +201 -0
  103. vllm_ascend-0.10.0rc1.dist-info/METADATA +130 -0
  104. vllm_ascend-0.10.0rc1.dist-info/RECORD +107 -0
  105. vllm_ascend-0.10.0rc1.dist-info/WHEEL +5 -0
  106. vllm_ascend-0.10.0rc1.dist-info/entry_points.txt +5 -0
  107. vllm_ascend-0.10.0rc1.dist-info/top_level.txt +1 -0
vllm_ascend/__init__.py
@@ -0,0 +1,27 @@
+ #
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # This file is a part of the vllm-ascend project.
+ #
+
+
+ def register():
+     """Register the NPU platform."""
+
+     return "vllm_ascend.platform.NPUPlatform"
+
+
+ def register_model():
+     from .models import register_model
+     register_model()
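The register() and register_model() functions above are the plugin hooks that vLLM discovers through this wheel's entry points (see entry_points.txt in the file list). As a hedged illustration, not part of the package, the class path that register() returns can be resolved with importlib; this assumes the wheel and its vLLM/torch-npu dependencies are importable in the current environment:

# Illustration only: resolve the platform class path returned by register().
import importlib

import vllm_ascend

path = vllm_ascend.register()  # "vllm_ascend.platform.NPUPlatform"
module_name, class_name = path.rsplit(".", 1)
platform_cls = getattr(importlib.import_module(module_name), class_name)
print(platform_cls.__name__)  # NPUPlatform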
vllm_ascend/_build_info.py
@@ -0,0 +1,3 @@
+ # Auto-generated file
+ __soc_version__ = 'ASCEND910B1'
+ __sleep_mode_enabled__ = True
vllm_ascend/_version.py
@@ -0,0 +1,21 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.10.0rc1'
+ __version_tuple__ = version_tuple = (0, 10, 0, 'rc1')
vllm_ascend/ascend_config.py
@@ -0,0 +1,183 @@
+ #
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ # This file is a part of the vllm-ascend project.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Optional
+
+ from vllm.logger import logger
+
+ TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
+
+
+ def _check_torchair_supported(model_type: str):
+     for supported_model in TORCHAIR_MODEL_LIST:
+         if supported_model in model_type.lower():
+             return True
+     return False
+
+
+ class AscendConfig:
+     """
+     Configuration Object for additional_config from vllm.configs.
+     """
+
+     def __init__(self, vllm_config):
+         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+
+         torchair_graph_config = additional_config.get("torchair_graph_config",
+                                                       {})
+         self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
+
+         ascend_scheduler_config = additional_config.get(
+             "ascend_scheduler_config", {})
+         self.ascend_scheduler_config = AscendSchedulerConfig(
+             ascend_scheduler_config)
+
+         self.expert_map_path = additional_config.get("expert_map_path", None)
+         self.chunked_prefill_for_mla = additional_config.get(
+             "chunked_prefill_for_mla", False)
+
+
+ class TorchairGraphConfig:
+     """
+     Configuration Object for torchair_graph_config from additional_config
+     """
+
+     def __init__(self, torchair_graph_config):
+         self.enabled = torchair_graph_config.get("enabled", False)
+         self.use_cached_graph = torchair_graph_config.get(
+             "use_cached_graph", False)
+         self.graph_batch_sizes = torchair_graph_config.get(
+             "graph_batch_sizes", [])
+         self.graph_batch_sizes_init = torchair_graph_config.get(
+             "graph_batch_sizes_init", False)
+         self.enable_multistream_mla = torchair_graph_config.get(
+             "enable_multistream_mla", False)
+         self.enable_multistream_moe = torchair_graph_config.get(
+             "enable_multistream_moe", False)
+         self.enable_view_optimize = torchair_graph_config.get(
+             "enable_view_optimize", True)
+         self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
+
+         if not isinstance(self.graph_batch_sizes, list):
+             raise TypeError("graph_batch_sizes must be list[int]")
+         if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
+             raise ValueError(
+                 "graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
+             )
+         if not self.enabled:
+             if self.use_cached_graph:
+                 raise RuntimeError(
+                     "use_cached_graph is valid only when Torchair graph mode is enabled"
+                 )
+             if self.graph_batch_sizes:
+                 raise RuntimeError(
+                     "graph_batch_sizes is valid only when Torchair graph mode is enabled"
+                 )
+             if self.graph_batch_sizes_init:
+                 raise RuntimeError(
+                     "graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
+                 )
+             if self.enable_multistream_mla:
+                 raise RuntimeError(
+                     "enable_multistream_mla is valid only when Torchair graph mode is enabled"
+                 )
+             if self.enable_multistream_moe:
+                 raise RuntimeError(
+                     "enable_multistream_moe is valid only when Torchair graph mode is enabled"
+                 )
+             if self.enable_kv_nz:
+                 raise RuntimeError(
+                     "enable_kv_nz is valid only when Torchair graph mode is enabled"
+                 )
+
+
+ class AscendSchedulerConfig:
+     """
+     Configuration Object for ascend_scheduler_config from additional_config
+     """
+
+     def __init__(self, ascend_scheduler_config: dict):
+         self.enabled = ascend_scheduler_config.get("enabled", False)
+         # Ascend scheduler is based on vllm v0 scheduler, so we should support
+         # all vllm v0 scheduler configs as well.
+         for k, v in ascend_scheduler_config.items():
+             if not hasattr(self, k):
+                 setattr(self, k, v)
+
+
+ _ASCEND_CONFIG: Optional[AscendConfig] = None
+
+
+ def init_ascend_config(vllm_config):
+     additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+     refresh = additional_config.get("refresh",
+                                     False) if additional_config else False
+     global _ASCEND_CONFIG
+     if _ASCEND_CONFIG is not None and not refresh:
+         return _ASCEND_CONFIG
+     _ASCEND_CONFIG = AscendConfig(vllm_config)
+     return _ASCEND_CONFIG
+
+
+ def clear_ascend_config():
+     global _ASCEND_CONFIG
+     _ASCEND_CONFIG = None
+
+
+ def get_ascend_config():
+     global _ASCEND_CONFIG
+     if _ASCEND_CONFIG is None:
+         raise RuntimeError(
+             "Ascend config is not initialized. Please call init_ascend_config first."
+         )
+     return _ASCEND_CONFIG
+
+
+ def check_ascend_config(vllm_config, enforce_eager):
+     ascend_config = get_ascend_config()
+
+     # for eager mode
+     if enforce_eager:
+         # torchair_graph cannot be enabled with eager mode.
+         if ascend_config.torchair_graph_config.enabled:
+             raise RuntimeError(
+                 "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
+             )
+     # for graph mode
+     else:
+         # torchair_graph case
+         if ascend_config.torchair_graph_config.enabled:
+             # torchair_graph is supported for deepseek/pangu model only.
+             if vllm_config.model_config:
+                 model_type = vllm_config.model_config.hf_config.model_type
+                 if not _check_torchair_supported(model_type):
+                     raise NotImplementedError(
+                         "Torchair graph mode only works with following model types:"
+                         f"{TORCHAIR_MODEL_LIST}.")
+         # aclgraph case
+         else:
+             # aclgraph doesn't work with deepseek model and only qwen model is well tested.
+             if vllm_config.model_config:
+                 model_type = vllm_config.model_config.hf_config.model_type
+                 if "deepseek" in model_type:
+                     raise NotImplementedError(
+                         "ACL Graph does not support deepseek. Please "
+                         "try torchair graph mode to serve deepseek models on vllm-ascend."
+                         " Or set `enforce_eager=True` to use eager mode.")
+                 if "qwen" not in model_type:
+                     logger.warning(
+                         "ACL Graph is currently experimental. Please "
+                         "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
+                         " if you encourage any Error")
vllm_ascend/ascend_forward_context.py
@@ -0,0 +1,114 @@
+ import math
+ from contextlib import contextmanager
+ from enum import Enum
+ from typing import Any, Optional
+
+ import torch
+ from vllm.config import VllmConfig
+ from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+ from vllm.forward_context import get_forward_context, set_forward_context
+
+ import vllm_ascend.envs as envs
+ from vllm_ascend.platform import NPUPlatform
+
+
+ class FusedMoEState(Enum):
+     AllGather = 0
+     All2All = 1
+     MC2 = 2
+     AllGatherEP = 3
+     NaiveMulticast = 4
+     All2AllSeq = 5
+
+
+ # TODO(zzzzwwjj): add soc_version to choose branch
+ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
+                          is_deepseek_v3_r1: bool):
+     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+     # only supports deepseek v3/r1
+     if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+             and is_deepseek_v3_r1):
+         return FusedMoEState.AllGatherEP
+     elif ep_size == 1:
+         if with_prefill:
+             return FusedMoEState.NaiveMulticast
+         else:
+             return FusedMoEState.AllGather
+     elif envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+         # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
+         return (FusedMoEState.All2AllSeq if
+                 (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
+     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
+     elif ep_size < 16 or with_prefill:
+         return FusedMoEState.All2All
+     else:
+         return FusedMoEState.MC2
+
+
+ @contextmanager
+ def set_ascend_forward_context(
+         attn_metadata: Any,
+         vllm_config: VllmConfig,
+         virtual_engine: int = 0,
+         num_tokens: Optional[int] = None,
+         num_tokens_across_dp: Optional[torch.Tensor] = None,
+         with_prefill: bool = True,
+         in_profile_run: bool = False,
+         num_actual_tokens: Optional[int] = None,
+ ):
+     """A context manager that stores the current forward context,
+     can be attention metadata, etc.
+     We add some additional param into forward_context.
+     """
+     with set_forward_context(attn_metadata,
+                              vllm_config,
+                              virtual_engine=virtual_engine,
+                              num_tokens=num_tokens,
+                              num_tokens_across_dp=num_tokens_across_dp):
+         forward_context = get_forward_context()
+         forward_context.with_prefill = with_prefill
+         ep_size = (get_ep_group().world_size if
+                    vllm_config.parallel_config.enable_expert_parallel else 1)
+
+         is_deepseek_v3_r1 = hasattr(
+             vllm_config.model_config.hf_config, 'n_routed_experts'
+         ) and vllm_config.model_config.hf_config.n_routed_experts == 256
+         fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
+                                                is_deepseek_v3_r1)
+         forward_context.fused_moe_state = fused_moe_state
+         forward_context.in_profile_run = in_profile_run
+
+         # NOTE: This cannot be set using set_forward_context
+         # due to multiple warmups before actual capturing
+         forward_context.capturing = False
+
+         if num_tokens is None and attn_metadata is not None:
+             num_tokens = attn_metadata.num_actual_tokens
+
+         dp_world_size = get_dp_group().world_size
+         if dp_world_size > 1 and forward_context.dp_metadata is not None:
+             max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
+             )
+         else:
+             max_tokens_across_dp = num_tokens
+
+         forward_context.max_tokens_across_dp = max_tokens_across_dp
+
+         if num_tokens is not None:
+             if num_actual_tokens is None:
+                 num_actual_tokens = num_tokens
+             tp_world_size = get_tp_group().world_size
+             # NOTE: token num which need to pad to when mc2
+             forward_context.padded_num_tokens = math.ceil(
+                 max_tokens_across_dp / tp_world_size) * tp_world_size
+
+             mc2_mask = torch.zeros(forward_context.padded_num_tokens,
+                                    dtype=torch.bool,
+                                    device=NPUPlatform.device_type)
+             mc2_mask[:num_actual_tokens] = True
+             forward_context.mc2_mask = mc2_mask
+
+         try:
+             yield
+         finally:
+             pass
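The helper _get_fused_moe_state above picks the MoE communication strategy from the expert-parallel size, the prefill flag, and the DeepSeek V3/R1 check. Below is a hedged sketch of the resulting decision table; it assumes both environment toggles (VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ) are disabled, that the wheel's vLLM/torch-npu dependencies are importable, and it calls an internal function that may change between releases:

# Illustration only: exercise the internal MoE dispatch-state selection.
from vllm_ascend.ascend_forward_context import _get_fused_moe_state

for ep_size, with_prefill in [(1, True), (1, False), (8, False), (32, False)]:
    state = _get_fused_moe_state(ep_size, with_prefill, is_deepseek_v3_r1=False)
    print(ep_size, with_prefill, state.name)
# Expected when both env toggles are disabled:
#   1  True   -> NaiveMulticast
#   1  False  -> AllGather
#   8  False  -> All2All
#   32 False  -> MC2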
vllm_ascend/attention/__init__.py: File without changes
vllm_ascend/attention/attention_mask.py
@@ -0,0 +1,104 @@
+ #
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+
+
+ def _generate_attn_mask(max_seq_len, dtype):
+     # Construct lower triangle matrix.
+     mask_flag = torch.tril(
+         torch.ones((max_seq_len, max_seq_len),
+                    dtype=torch.bool)).view(max_seq_len, max_seq_len)
+     # Create upper triangle matrix used to mark mask positions.
+     mask_flag = ~mask_flag
+     # Currently for fp16 dtype, the mask value should be set to -inf.
+     # TODO: Eliminate this part in the future.
+     if dtype == torch.float16:
+         mask_value = torch.finfo(torch.float32).min
+     else:
+         mask_value = 1
+     attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
+                                   mask_flag, mask_value).to(dtype)
+     return attn_mask
+
+
+ class AttentionMaskBuilder:
+
+     def __init__(
+         self,
+         max_seq_len: int,
+         dtype: torch.dtype,
+     ):
+         attn_mask = _generate_attn_mask(max_seq_len, dtype)
+
+         self._seq_len_cached = attn_mask.shape[0]
+         self.attn_mask_cache = attn_mask
+         self.splitfuse_mask_value = -10000
+
+     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
+                       device: torch.device):
+         self._update_attn_cache(max_seq_len, dtype, device)
+         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+
+     def get_splitfuse_attn_mask(
+         self,
+         seq_lens,
+         query_lens,
+         position,
+         dtype,
+         device,
+     ) -> torch.Tensor:
+         max_seq_len = max(seq_lens, default=0)
+         if max_seq_len <= self._seq_len_cached:
+             self._update_attn_cache(max_seq_len, dtype, device)
+             # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+             # is not the same. Fix this in the future when kernel is ready.
+             if self.attn_mask_cache.numel(
+             ) > 1 and self.attn_mask_cache[0][1] > 0:
+                 attn_mask = self.get_attn_mask(  # type: ignore
+                     max_seq_len, dtype, device)
+                 # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
+                 attn_mask = attn_mask * -10000
+             else:
+                 attn_mask = self.attn_mask_cache
+             return torch.index_select(attn_mask, dim=0,
+                                       index=position)[:, :max_seq_len]
+         total_q_len = sum(query_lens)
+         attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                 dtype=dtype,
+                                 device="cpu")
+         current_row = 0
+         for i in range(len(query_lens)):
+             seq_len = seq_lens[i]
+             q_len = query_lens[i]
+             context_len = seq_len - q_len
+
+             assert context_len >= 0
+             attn_mask[current_row:current_row + q_len,
+                       context_len:] = self.splitfuse_mask_value
+             right_tensor = attn_mask[current_row:current_row + q_len,
+                                      context_len:seq_len]
+             right_tensor.masked_fill_(
+                 right_tensor.tril() == self.splitfuse_mask_value, 0)
+             current_row += q_len
+
+         return attn_mask.to(device, non_blocking=True)
+
+     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
+                            device: torch.device):
+         if seqlen > self._seq_len_cached:
+             self._seq_len_cached = seqlen
+             self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
+         if self.attn_mask_cache.device != device:
+             self.attn_mask_cache = self.attn_mask_cache.to(device)
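AttentionMaskBuilder keeps a single lower-triangular mask cached and either slices it or regenerates a larger one on demand. A hedged usage sketch on CPU (this snippet is not part of the package; it only needs torch and the module above):

# Illustration only: build and inspect causal masks on CPU.
import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=8, dtype=torch.float16)

# A request within the cached length slices the cached mask.
mask = builder.get_attn_mask(4, torch.float16, torch.device("cpu"))
print(mask.shape)         # torch.Size([4, 4])
print(mask[0, 1].item())  # -inf: positions above the diagonal are masked

# A longer request regenerates and re-caches the mask.
mask = builder.get_attn_mask(16, torch.float16, torch.device("cpu"))
print(mask.shape)         # torch.Size([16, 16])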