mindspore 2.2.10__cp37-none-any.whl → 2.2.14__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +2 -1
- mindspore/_akg/akg/composite/build_module.py +95 -5
- mindspore/_akg/akg/topi/cpp/impl.py +1 -1
- mindspore/_akg/akg/tvm/_ffi/base.py +1 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/util.py +18 -1
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +12 -2
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/tensor.py +0 -2
- mindspore/communication/management.py +3 -0
- mindspore/context.py +34 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets.py +23 -0
- mindspore/dataset/engine/validators.py +1 -1
- mindspore/dataset/vision/py_transforms_util.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +5 -6
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +118 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
- mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
- mindspore/mindrecord/tools/csv_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
- mindspore/nn/layer/activation.py +1 -1
- mindspore/nn/layer/embedding.py +2 -2
- mindspore/nn/layer/flash_attention.py +48 -135
- mindspore/nn/loss/loss.py +1 -1
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/sgd.py +3 -2
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +6 -3
- mindspore/numpy/math_ops.py +1 -1
- mindspore/ops/__init__.py +3 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/function/array_func.py +6 -5
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/linalg_func.py +21 -11
- mindspore/ops/function/math_func.py +3 -0
- mindspore/ops/function/nn_func.py +13 -11
- mindspore/ops/function/parameter_func.py +2 -0
- mindspore/ops/function/sparse_unary_func.py +2 -2
- mindspore/ops/function/vmap_func.py +1 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +3 -4
- mindspore/ops/operations/_inner_ops.py +56 -1
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +15 -4
- mindspore/ops/operations/custom_ops.py +1 -1
- mindspore/ops/operations/debug_ops.py +1 -1
- mindspore/ops/operations/image_ops.py +3 -3
- mindspore/ops/operations/inner_ops.py +49 -0
- mindspore/ops/operations/math_ops.py +65 -3
- mindspore/ops/operations/nn_ops.py +95 -28
- mindspore/ops/operations/random_ops.py +2 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/silent_check.py +162 -0
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +82 -3
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_tensor.py +3 -1
- mindspore/parallel/_transformer/transformer.py +8 -8
- mindspore/parallel/checkpoint_transform.py +191 -45
- mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
- mindspore/profiler/parser/ascend_flops_generator.py +8 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
- mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
- mindspore/profiler/parser/ascend_op_generator.py +15 -7
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
- mindspore/profiler/parser/base_timeline_generator.py +11 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +8 -2
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
- mindspore/profiler/parser/msadvisor_parser.py +9 -3
- mindspore/profiler/profiling.py +97 -25
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/api/symbol_tree.py +2 -2
- mindspore/rewrite/parsers/for_parser.py +6 -6
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_checkpoint.py +8 -8
- mindspore/train/callback/_landscape.py +2 -3
- mindspore/train/callback/_summary_collector.py +6 -7
- mindspore/train/dataset_helper.py +6 -0
- mindspore/train/model.py +17 -5
- mindspore/train/serialization.py +6 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +5 -6
- mindspore/version.py +1 -1
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +141 -149
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py (deleted)
@@ -1,406 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""the base class of flash attention"""
-from abc import ABCMeta
-from abc import abstractmethod
-from functools import partial
-from collections import defaultdict
-
-import te.platform as tbe_platform
-from tbe import tik
-from tbe.common.platform import get_soc_spec
-
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import MASK_FILL_VALUE
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
-from mindspore.ops._op_impl._custom_op.flash_attention.tik_ops_utils import TikOpsUtils
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.sparse_tiling import SparseTiling
-
-
-class FlashAttention(metaclass=ABCMeta):
-    """The base class of FlashAttention"""
-
-    def __init__(self, q, k, v, attn_mask, dropout_mask, alibi_mask, kernel_name,
-                 tiling_stgy_cls,
-                 prev_block_num=65536,
-                 next_block_num=65536,
-                 high_precision=False,
-                 disable_debug=True):
-        """
-        Init parameter shape
-        :param q: with shape: (B, h, N, d)
-        :param k: with shape: (B, h, N, d)
-        :param v: with shape: (B, h, N, d)
-        :param attn_mask: with shape: (1, N, N) or (B, N, N)
-        :param dropout_mask: with shape: (B, h, N, N)
-        :param alibi_mask: with shape: (B, h, 1, N)
-        :param kernel_name:
-        :param tiling_stgy_cls:
-        :param prev_block_num:
-        :param next_block_num:
-        :param disable_debug:
-        """
-        self.tik_instance = tik.Tik(disable_debug=disable_debug)
-        self.core_num = get_soc_spec(tbe_platform.CORE_NUM)
-        self.M = tbe_platform.get_soc_spec(tbe_platform.L1_SIZE)
-        self.kernel_name = kernel_name
-        self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
-                                           src_stride=0,
-                                           dst_stride=0)
-        self.tik_ops_utils = TikOpsUtils(self.tik_instance)
-        self.parse_input_shape(alibi_mask, attn_mask, dropout_mask, k, q, v)
-        # NZ
-        _, _, N1, M1, M0, N0 = self.q_shape
-        self.M1 = M1
-        self.N1 = N1
-        self.M0 = M0
-        self.N0 = N0
-        self.d = N1 * N0
-        # ND
-        batch_size, h, Nq, actual_d = self.q_ori_shape
-        self.head_num = h
-        self.B, self.Nq = batch_size * h, Nq
-        self.N = self.k_ori_shape[2]
-        self.actual_d = actual_d
-
-        self.l_shape = [batch_size, h, self.Nq]
-        self.m_shape = [batch_size, h, self.Nq]
-        self.O_shape = self.q_shape
-
-        self.prev_block_num = prev_block_num
-        self.next_block_num = next_block_num
-        self.high_precision = high_precision
-        if self.high_precision:
-            self.precision_type = FP32
-        else:
-            self.precision_type = FP16
-        if tiling_stgy_cls is None:
-            self.tiling_stgy = SparseTiling(self.Nq, self.N, self.d)
-        else:
-            self.tiling_stgy: TilingStrategy = tiling_stgy_cls(self.Nq, self.N, self.d)
-        self.Br = None
-        self.last_Br = None
-        self.Bc = None
-        self.last_Bc = None
-        self.Tr = None
-        self.Tc = None
-        self.Q_gm = None
-        self.K_gm = None
-        self.V_gm = None
-        self.dim_mask_gm = None
-        self.att_mask_gm = None
-        self.drop_mask_gm = None
-        self.alibi_mask_gm = None
-
-    @staticmethod
-    def get_l_m_gm_offset(batch_start, batch_idx, h, block_h, block_idx):
-        """get l m gm offset"""
-        gm_offset = (batch_start + batch_idx) * h + block_idx * block_h
-        return gm_offset
-
-    @staticmethod
-    def get_alibi_gm_offset(batch_start, batch_idx, w, block_w, block_idx):
-        """get alibi gm offset"""
-        gm_offset = (batch_start + batch_idx) * w + block_idx * block_w
-        return gm_offset
-
-    @staticmethod
-    def get_drop_mask_gm_offset(batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
-        """get drop mask gm offset"""
-        gm_offset = (batch_start + batch_idx) * h * w + block_h_idx * (w * block_h) + block_w_idx * block_w
-        return gm_offset
-
-    @abstractmethod
-    def define_custom_inputs(self):
-        """define custom inputs"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def define_outputs(self):
-        """define outputs"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_inputs(self):
-        """collect inputs"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_outputs(self):
-        """collect outputs"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
-        """compute one core"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare_global_ones(self):
-        """prepare global ones"""
-        raise NotImplementedError
-
-    def get_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_idx):
-        """get gm offset"""
-        gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * self.N0
-        return gm_offset
-
-    def get_cur_tr_block_num(self, tr_idx):
-        """get cur tr block_num"""
-        cur_prev_block_num = min(tr_idx, self.prev_block_num)
-        cur_next_block_num = min(self.next_block_num, self.Tc - tr_idx - 1)
-        block_num = cur_prev_block_num + 1 + cur_next_block_num
-        return block_num
-
-    def get_total_block_num(self):
-        """get total block num"""
-        block_num = 0
-        for b_idx in range(self.B):
-            for tr_idx in range(self.Tr):
-                block_num += self.get_cur_tr_block_num(tr_idx)
-        return block_num
-
-    def update_core_task_map(self,
-                             core_b_map,
-                             core_b_tr_map,
-                             core_idx,
-                             b_start,
-                             b_end,
-                             tr_start,
-                             tr_end):
-        """update core task map"""
-        core_b_map[core_idx][0] = min(core_b_map[core_idx][0], b_start)
-        if tr_end == 0:  # crosses a head boundary, but the crossed head is not handled by the current core
-            core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end - 1)
-        else:
-            core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end)
-        for b_idx in range(b_start, b_end + 1):
-            if b_idx == b_end and tr_end == 0:  # crosses a head boundary, but the crossed head is not handled by the current core
-                break
-            elif b_idx == b_start and b_idx == b_end:  # no head boundary crossed
-                core_b_tr_map[core_idx][b_idx] = (tr_start, tr_end)
-            elif b_idx == b_start:  # crosses heads: the first head
-                core_b_tr_map[core_idx][b_idx] = (tr_start, self.Tr)
-            elif b_idx == b_end:  # crosses heads: the last head
-                core_b_tr_map[core_idx][b_idx] = (0, tr_end)
-            else:  # crosses heads: an intermediate head
-                core_b_tr_map[core_idx][b_idx] = (0, self.Tr)
-
-    def convert_py_dict_to_tik_tensor(self, core_b_map, core_b_tr_map):
-        """convert py dict to tik tensor"""
-        # python dict -> tik tensor
-        # [batch_start, batch_idx_end] -> [batch_start, batch_num]
-        # [tr_start, tr_idx_end] -> [tr_start, tr_idx_end)
-        core_idx_to_batch_info = self.tik_instance.Tensor(
-            "int32", (self.core_num, 2), name="core_idx_to_batch_info", scope=UB
-        )
-        core_idx_to_tr_info = self.tik_instance.Tensor(
-            "int32", (self.core_num, self.B, 2), name="core_idx_to_tr_info", scope=UB
-        )
-        for core_idx in core_b_map.keys():
-            batch_start, batch_end = core_b_map[core_idx]
-            core_idx_to_batch_info[core_idx, 0] = batch_start
-            core_idx_to_batch_info[core_idx, 1] = batch_end - batch_start + 1
-            for batch_idx in core_b_tr_map[core_idx].keys():
-                tr_start, tr_end = core_b_tr_map[core_idx][batch_idx]
-                core_idx_to_tr_info[core_idx, batch_idx, 0] = tr_start
-                core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end
-
-        return core_idx_to_batch_info, core_idx_to_tr_info
-
-    def get_core_task_info(self):
-        """
-        Get batch start and batch number of each NPU core.
-        :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
-        core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
-        core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
-        core_m_batch_n_Tr_end]]
-        """
-        if self.core_num > self.B * self.Tr:
-            self.core_num = self.B * self.Tr
-
-        total_blk_num = self.get_total_block_num()
-        b_start = 0
-        tr_start = 0
-        remain_blk_num = total_blk_num
-        core_b_map = defaultdict(lambda: [100000, -1])
-        core_b_tr_map = defaultdict(lambda: defaultdict(list))
-        for core_idx in range(self.core_num):
-            cur_core_blk_num = 0
-            cur_each_core_blk_num = remain_blk_num // (self.core_num - core_idx)
-            cur_core_finished = False
-            b_end = b_start
-            tr_end = tr_start
-            while b_end < self.B:
-                while tr_end < self.Tr:
-                    cur_tr_blk_num = self.get_cur_tr_block_num(tr_end)
-                    if abs(cur_core_blk_num - cur_each_core_blk_num) <= \
-                            (cur_core_blk_num + cur_tr_blk_num - cur_each_core_blk_num):
-                        self.update_core_task_map(core_b_map, core_b_tr_map, core_idx, b_start, b_end, tr_start, tr_end)
-                        remain_blk_num -= cur_core_blk_num
-                        cur_core_finished = True
-                        break
-                    else:
-                        cur_core_blk_num += cur_tr_blk_num
-                        tr_end += 1
-                        if tr_end == self.Tr:
-                            tr_end = 0
-                            b_end += 1
-                if cur_core_finished:
-                    b_start = b_end
-                    tr_start = tr_end
-                    break
-        core_idx_to_batch_info, core_idx_to_tr_info = self.convert_py_dict_to_tik_tensor(core_b_map, core_b_tr_map)
-        return core_idx_to_batch_info, core_idx_to_tr_info
-
-    def get_attn_mask_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
-        """get attn mask gm offset"""
-        if self.att_mask_shape[0] == 1:
-            gm_offset = block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
-        else:
-            gm_offset = ((batch_start + batch_idx) // self.head_num) * h * w \
-                        + block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
-        return gm_offset
-
-    def parse_input_shape(self, alibi_mask, attn_mask, dropout_mask, k, q, v):
-        """parser input shape"""
-        self.has_attn_mask = False
-        self.has_drop_mask = False
-        self.has_alibi_mask = False
-        # NZ
-        self.q_shape = q["shape"]
-        self.k_shape = k["shape"]
-        self.v_shape = v["shape"]
-        # ND
-        self.q_ori_shape = q["ori_shape"]
-        self.k_ori_shape = k["ori_shape"]
-        if attn_mask is not None:
-            self.has_attn_mask = True
-            self.att_mask_shape = attn_mask["shape"]
-        if dropout_mask is not None:
-            self.has_drop_mask = True
-            self.drop_mask_shape = dropout_mask["shape"]
-        if alibi_mask is not None:
-            self.has_alibi_mask = True
-            self.alibi_mask_shape = alibi_mask["shape"]
-
-    def define_inputs_outputs(self):
-        """define inputs outputs"""
-        self.define_common_inputs()
-
-        self.define_custom_inputs()
-
-        self.define_outputs()
-
-    def init(self):
-        """init parameters"""
-        tiling_para: TilingPara = self.tiling_stgy.tiling()
-
-        self.Br = tiling_para.Br
-        self.last_Br = tiling_para.last_Br
-        self.Bc = tiling_para.Bc
-        self.last_Bc = tiling_para.last_Bc
-        self.Tr = tiling_para.Tr
-        self.Tc = tiling_para.Tc
-
-        self.define_inputs_outputs()
-
-    def define_common_inputs(self):
-        """define common input gm tensors"""
-        self.Q_gm = self.tik_instance.Tensor(FP16, self.q_shape, name="Q_gm", scope=GM)
-        self.K_gm = self.tik_instance.Tensor(FP16, self.k_shape, name="K_gm", scope=GM)
-        self.V_gm = self.tik_instance.Tensor(FP16, self.v_shape, name="V_gm", scope=GM)
-        if self.has_attn_mask:
-            self.att_mask_gm = self.tik_instance.Tensor(FP16, self.att_mask_shape,
-                                                        name="att_mask_gm", scope=GM)
-        if self.has_drop_mask:
-            self.drop_mask_gm = self.tik_instance.Tensor(FP16, self.drop_mask_shape,
-                                                         name="drop_mask_gm", scope=GM)
-        if self.has_alibi_mask:
-            self.alibi_mask_gm = self.tik_instance.Tensor(FP16, self.alibi_mask_shape,
-                                                          name="alibi_mask_gm", scope=GM)
-
-    def do_alibi_mask(self, Sij_ub, alibi_mask_gm_offset, m_aligned, n_aligned):
-        """load alibi mask from gm to ub, then add Sij"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            alibi_mask_ub = self.tik_instance.Tensor(FP16, (1, n_aligned),
-                                                     scope=UB, name="alibi_mask_ub")
-            self.tik_instance.data_move(alibi_mask_ub, self.alibi_mask_gm[alibi_mask_gm_offset], 0, 1,
-                                        n_aligned // 16, 0, 0)
-            alibi_mask_ub_broadcast = self.tik_ops_utils.broadcast_row(alibi_mask_ub, (m_aligned, n_aligned))
-            self.tik_instance.h_add(Sij_ub, Sij_ub, alibi_mask_ub_broadcast)
-
-    def do_att_mask(self, Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
-                    q_blk_h_aligned, kv_blk_h_aligned):
-        """load attn mask from gm to ub, then mul it by MASK_FILL_VALUE and add Sij"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            att_mask_ub = self.tik_instance.Tensor(FP16, (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
-                                                   scope=UB, name="att_mask_ub")
-            self.tik_instance.data_move(att_mask_ub, self.att_mask_gm[attn_mask_gm_offset], 0,
-                                        kv_blk_height // self.N0, q_blk_height * self.N0 // 16,
-                                        (self.Nq - q_blk_height) * self.N0 // 16, 0)
-            self.tik_instance.h_mul(att_mask_ub, att_mask_ub, MASK_FILL_VALUE)
-            self.tik_instance.h_add(Sij_ub_N1MN0, Sij_ub_N1MN0, att_mask_ub)
-
-    def do_dropout_mask(self, Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
-                        q_blk_h_aligned, q_blk_height, precision_type=FP16, workspace=None):
-        """load drop mask from gm to ub, then mul it by Pij"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dropout_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
-                                                       scope=UB, name="drop_mask_ub")
-            self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
-                                        q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
-            dropout_mask_ub = dropout_mask_ub.reshape((kv_blk_height // self.N0, q_blk_height, self.N0))
-            if precision_type == FP32:
-                dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32,
-                                                                (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
-                                                                scope=UB, name="dropout_mask_ub_fp32")
-                self.tik_instance.h_cast(dropout_mask_ub_fp32, dropout_mask_ub, "none")
-                if workspace is None:
-                    self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
-                else:
-                    self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub_fp32)
-            else:
-                if workspace is None:
-                    self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
-                else:
-                    self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub)
-
-    def compute_process(self):
-        """The compute process of FlashAttention"""
-        self.init()
-        self.prepare_global_ones()
-        core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_task_info()
-        with self.tik_instance.for_range(begint=0, endt=self.core_num, name="core_index",
-                                         block_num=self.core_num) as core_idx:
-            batch_start_s = self.tik_instance.Scalar("int32", name="batch_start_s")
-            batch_num_s = self.tik_instance.Scalar("int32", name="batch_num_s")
-
-            batch_start_s.set_as(core_idx_to_batch_info[core_idx, 0])
-            batch_num_s.set_as(core_idx_to_batch_info[core_idx, 1])
-
-            self.compute_one_core(batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx)
-
-        self.tik_instance.BuildCCE(
-            kernel_name=self.kernel_name,
-            inputs=self.collect_inputs(),
-            outputs=self.collect_outputs(),
-            config={"dump_cce_code": False, "save_temp_cce_file": True, "enable_const_fold": True},
-            enable_l2=True
-        )
mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py (deleted)
@@ -1,41 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""define constants"""
-from tbe import tik
-
-BLOCK_NUM = 16
-FP16 = "float16"
-INT8 = "int8"
-INT32 = "int32"
-FP32 = "float32"
-REPEAT_SZ = 128
-BLK_STRIDE = 1
-REPEAT_STRIDE = 8
-TRANS_CUBE_TGT = 8
-FP16_MIN_VAL = -65504.0
-MASK_FILL_VALUE = -10000.0
-GM = tik.scope_gm
-L1 = tik.scope_cbuf
-L1OUT = tik.scope_cbuf_out
-UB = tik.scope_ubuf
-L0A = tik.scope_ca
-L0B = tik.scope_cb
-L0C = tik.scope_cc
-DTYPE_SIZE = {
-    "int8": 1,
-    "float16": 2,
-    "int16": 2,
-    "float32": 4,
-}