mindspore 2.2.10__cp39-cp39-win_amd64.whl → 2.2.14__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +2 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +12 -2
- mindspore/common/_utils.py +16 -0
- mindspore/common/tensor.py +0 -2
- mindspore/communication/management.py +3 -0
- mindspore/context.py +34 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets.py +23 -0
- mindspore/dataset/engine/validators.py +1 -1
- mindspore/dataset/vision/py_transforms_util.py +2 -2
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/optim/lr_scheduler.py +5 -6
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
- mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
- mindspore/mindrecord/tools/csv_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/activation.py +1 -1
- mindspore/nn/layer/embedding.py +2 -2
- mindspore/nn/layer/flash_attention.py +48 -135
- mindspore/nn/loss/loss.py +1 -1
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/sgd.py +3 -2
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +6 -3
- mindspore/numpy/math_ops.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/function/array_func.py +6 -5
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/linalg_func.py +21 -11
- mindspore/ops/function/math_func.py +3 -0
- mindspore/ops/function/nn_func.py +13 -11
- mindspore/ops/function/parameter_func.py +2 -0
- mindspore/ops/function/sparse_unary_func.py +2 -2
- mindspore/ops/function/vmap_func.py +1 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +3 -4
- mindspore/ops/operations/_inner_ops.py +56 -1
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +15 -4
- mindspore/ops/operations/custom_ops.py +1 -1
- mindspore/ops/operations/debug_ops.py +1 -1
- mindspore/ops/operations/image_ops.py +3 -3
- mindspore/ops/operations/inner_ops.py +49 -0
- mindspore/ops/operations/math_ops.py +65 -3
- mindspore/ops/operations/nn_ops.py +95 -28
- mindspore/ops/operations/random_ops.py +2 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/silent_check.py +162 -0
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +82 -3
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_tensor.py +3 -1
- mindspore/parallel/_transformer/transformer.py +8 -8
- mindspore/parallel/checkpoint_transform.py +191 -45
- mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
- mindspore/profiler/parser/ascend_flops_generator.py +8 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
- mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
- mindspore/profiler/parser/ascend_op_generator.py +15 -7
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
- mindspore/profiler/parser/base_timeline_generator.py +11 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +8 -2
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
- mindspore/profiler/parser/msadvisor_parser.py +9 -3
- mindspore/profiler/profiling.py +97 -25
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/api/symbol_tree.py +2 -2
- mindspore/rewrite/parsers/for_parser.py +6 -6
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/callback/_checkpoint.py +8 -8
- mindspore/train/callback/_landscape.py +2 -3
- mindspore/train/callback/_summary_collector.py +6 -7
- mindspore/train/dataset_helper.py +6 -0
- mindspore/train/model.py +17 -5
- mindspore/train/serialization.py +6 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +5 -6
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +117 -124
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
|
@@ -1,563 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""The forward tik ops of flash attention"""
|
|
16
|
-
|
|
17
|
-
from tbe import tik
|
|
18
|
-
|
|
19
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.attention import FlashAttention
|
|
20
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
|
|
21
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
|
|
22
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
|
|
23
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
|
|
24
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
|
|
25
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
|
|
26
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class FlashAttentionFwd(FlashAttention):
|
|
30
|
-
"""The implementation of flash attention forward
|
|
31
|
-
This function contains the flash attention forward implementation used in flash attention (see paper)
|
|
32
|
-
`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/pdf/2205.14135.pdf>`
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
def __init__(self, query, key, value,
             attn_mask, dropout_mask, alibi_mask,
             kernel_name,
             tiling_stgy: TilingStrategy,
             prev_block_num=65536,
             next_block_num=65536, high_precision=False, disable_debug=True):
    """Create the forward flash-attention kernel builder.

    Every argument is forwarded, positionally and unchanged, to ``FlashAttention.__init__``.
    The global-memory output handles (``O_gm``, ``l_gm``, ``m_gm`` and ``O_gm_workspace``)
    start out as ``None``; ``define_outputs`` assigns them.
    """
    parent_args = (query, key, value, attn_mask, dropout_mask, alibi_mask,
                   kernel_name, tiling_stgy, prev_block_num, next_block_num,
                   high_precision, disable_debug)
    super(FlashAttentionFwd, self).__init__(*parent_args)
    # Output handles; populated by define_outputs().
    self.O_gm = self.l_gm = self.m_gm = self.O_gm_workspace = None
|
|
49
|
-
|
|
50
|
-
def define_custom_inputs(self):
    """Declare custom kernel inputs; the forward kernel declares none, so this is a no-op."""
    return None
|
|
52
|
-
|
|
53
|
-
def define_outputs(self):
    """Allocate the global-memory output tensors of the forward kernel.

    Creates, in this order:
      * ``O_gm``  -- FP16, shape ``self.O_shape``;
      * ``O_gm_workspace`` -- FP32, shape ``self.O_shape``, flagged as a workspace
        (only when ``self.high_precision`` is set);
      * ``l_gm``  -- ``self.precision_type``, shape ``self.l_shape``;
      * ``m_gm``  -- FP16, shape ``self.m_shape``.
    All of them live in GM and are created with ``is_atomic_add=True``.
    """
    new_tensor = self.tik_instance.Tensor
    self.O_gm = new_tensor(FP16, self.O_shape, name="O_gm", scope=GM, is_atomic_add=True)
    if self.high_precision:
        # FP32 buffer in which update_o_m_l_fp32 accumulates O.
        self.O_gm_workspace = new_tensor(FP32, self.O_shape, name="O_gm_workspace", scope=GM,
                                         is_workspace=True, is_atomic_add=True)
    self.l_gm = new_tensor(self.precision_type, self.l_shape, name="l_gm", scope=GM, is_atomic_add=True)
    self.m_gm = new_tensor(FP16, self.m_shape, name="m_gm", scope=GM, is_atomic_add=True)
|
|
62
|
-
|
|
63
|
-
def prepare_global_ones(self):
    """Prepare all-ones FP16 tensors in L1 for the cube-unit row sum.

    Two tensors of shape ``(aligned_width, 16)`` are built, where each width is rounded
    up to a multiple of 16:
      * ``self.ones_l1``      -- width ``self.Bc``; softmax_compute uses it when the
        key/value block is exactly ``self.Bc`` wide;
      * ``self.last_ones_l1`` -- width ``self.last_Bc``; used for any other block width
        (presumably the trailing, shorter block).
    Each tensor is filled with 1.0 in a UB scratch buffer and then copied to L1.
    """
    Bc_aligned = (self.Bc + 15) // 16 * 16
    last_Bc_aligned = (self.last_Bc + 15) // 16 * 16
    self.ones_l1 = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_l1", scope=L1)
    self.last_ones_l1 = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_l1", scope=L1)
    with self.tik_instance.new_stmt_scope(disable_sync=False):
        ones_ub = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_ub", scope=UB)
        self.tik_instance.h_duplicate(ones_ub, 1.0)
        self.cont_data_mv_1_bust(dst=self.ones_l1, src=ones_ub, burst=Bc_aligned)
        last_ones_ub = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_ub", scope=UB)
        # Fill the buffer that is actually copied below. The original code filled
        # ones_ub a second time, which left last_ones_ub uninitialised and copied
        # garbage into last_ones_l1.
        self.tik_instance.h_duplicate(last_ones_ub, 1.0)
        self.cont_data_mv_1_bust(dst=self.last_ones_l1, src=last_ones_ub, burst=last_Bc_aligned)
|
|
76
|
-
|
|
77
|
-
def softmax_compute(self, Sij_ub, mij_ub, lij_ub, m, n):
    """Refer to Algorithm 2 line12

    Compute the block-local softmax statistics of the score block ``Sij``:
      * ``mij`` = row-wise max of ``Sij`` (written into ``mij_ub``);
      * ``Sij`` <- exp(``Sij`` - ``mij``), computed in place in ``Sij_ub``
        (through an FP32 copy when ``self.high_precision`` is set);
      * a row sum of the exponentials via ``tik_ops_utils.row_sum_cube_impl``, which
        multiplies by the all-ones L1 tensor built in prepare_global_ones.

    :param Sij_ub: FP16 score block in UB, indexed as ``Sij_ub[idx, :, :]`` for
        ``idx < n // 16``; assumed (n // 16, m, 16) layout -- TODO confirm. Modified in place.
    :param mij_ub: UB vector that receives the row max.
    :param lij_ub: UB vector handed to ``row_sum_cube_impl`` (presumably its output
        buffer -- verify in tik_ops_utils).
    :param m: number of rows (query positions) in the block.
    :param n: number of columns (key positions) in the block; only n % 16 == 0 is supported.
    :return: ``(P, mij_ub, row_sum)`` where ``P`` is the FP32 UB tensor in
        high-precision mode, the FP16 UB tensor ``Sij_ub`` when a dropout mask is present,
        and otherwise the L1 copy ``Sij_l1_K1MK0_ed``.
    """
    m_aligned = self.tik_ops_utils.up_align_to_K0(m)
    n_aligned = self.tik_ops_utils.up_align_to_K0(n)
    n0 = 16
    n1 = n // 16
    # only support n % 16 == 0 (n1 counts whole 16-column blocks; a remainder is ignored)
    with self.tik_instance.new_stmt_scope(disable_sync=False):
        mn0_block_max = self.tik_instance.Tensor(FP16, (1, m, n0), name="mn0_block_max", scope=UB)
        # Seed the running max with the first 16-column block (m bursts of 16 FP16 values).
        self.cont_data_mv_1_bust(dst=mn0_block_max, src=Sij_ub, burst=m)
        # Element-wise max across the remaining 16-column blocks.
        with self.tik_instance.for_range(1, n1) as idx:
            self.tik_instance.h_max(mn0_block_max, mn0_block_max, Sij_ub[idx, :, :])
        mn0_block_max = mn0_block_max.reshape((m, n0))
        # Reduce the 16 lanes of each row to the final row max.
        self.tik_instance.h_reduce_max(mij_ub, mn0_block_max, 1)
    # Sij - mij
    with self.tik_instance.new_stmt_scope(disable_sync=False):
        broadcast_mij_ub = self.tik_ops_utils.broadcast(mij_ub, (m, n0))
        broadcast_mij_ub = broadcast_mij_ub.reshape((1, m, n0))
        for idx in range(n1):
            self.tik_instance.h_sub(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_mij_ub)
    # exp
    if self.high_precision:
        # Exponentiate in FP32; Sij_ub_fp32 is allocated outside the scope because it is returned.
        Sij_ub_fp32 = self.tik_instance.Tensor(
            FP32, (n_aligned // 16, m_aligned, 16), name="Sij_ub_fp32", scope=UB
        )
        with self.tik_instance.new_stmt_scope(disable_sync=False):
            self.tik_instance.h_cast(Sij_ub_fp32, Sij_ub, "none")
            self.tik_instance.h_exp(Sij_ub_fp32, Sij_ub_fp32)
            # Keep an FP16 copy as well; it feeds the cube row sum below.
            self.tik_instance.h_cast(Sij_ub, Sij_ub_fp32, "none")
    else:
        self.tik_instance.h_exp(Sij_ub, Sij_ub)

    # cube impl rowsum
    Sij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (n_aligned // 16, m_aligned, 16),
                                               name="Sij_l1_K1MK0_ed", scope=L1)
    self.cont_data_mv_1_bust(dst=Sij_l1_K1MK0_ed, src=Sij_ub, burst=m * n // 16)
    # Pick the all-ones operand whose width matches this block (see prepare_global_ones).
    if n == self.Bc:
        Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.ones_l1,
                                                              lij_ub, m, n, self.precision_type)
    else:
        Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.last_ones_l1,
                                                              lij_ub, m, n, self.precision_type)

    if self.high_precision:
        return Sij_ub_fp32, mij_ub, Sij_row_sum_ub
    if self.has_drop_mask:
        return Sij_ub, mij_ub, Sij_row_sum_ub
    return Sij_l1_K1MK0_ed, mij_ub, Sij_row_sum_ub
|
|
125
|
-
|
|
126
|
-
def update_m_l(self, mi_old_ub, mij_ub, li_old_ub, lij_ub, vec_len):
    """Refer to Algorithm 2 line13
    mi_new = max(mi, mij), li_new = exp(mi-mi_new)*li + exp(mij - mi_new) * lij

    Side effects: ``mi_old_ub`` is overwritten with exp(mi - mi_new) and ``mij_ub`` with
    exp(mij - mi_new). Callers (update_o_m_l / update_o_m_l_fp32) reuse these buffers as
    the rescaling factors afterwards.

    :param mi_old_ub: previous row max (FP16 UB vector); overwritten.
    :param mij_ub: current block row max (FP16 UB vector); overwritten.
    :param li_old_ub: previous row sum; its dtype determines the dtype of ``li_new_ub``.
    :param lij_ub: current block row sum.
    :param vec_len: number of valid rows; buffers are sized to the 16-aligned length.
    :return: ``(mi_new_ub, li_new_ub)`` -- new FP16 row max and new row sum.
    """
    dtype = li_old_ub.dtype
    vec_len_aligned = self.tik_ops_utils.up_align_to_K0(vec_len)
    mi_new_ub = self.tik_instance.Tensor(FP16, (vec_len_aligned,), name="mi_new_ub", scope=UB)
    li_new_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), name="li_new_ub", scope=UB)
    self.tik_instance.h_max(mi_new_ub, mi_old_ub, mij_ub)

    self.tik_instance.h_sub(mi_old_ub, mi_old_ub, mi_new_ub)  # mi-mi_new
    self.tik_instance.h_exp(mi_old_ub, mi_old_ub)  # exp(mi-mi_new)

    self.tik_instance.h_sub(mij_ub, mij_ub, mi_new_ub)  # mij-mi_new
    self.tik_instance.h_exp(mij_ub, mij_ub)  # exp(mij-mi_new)

    with self.tik_instance.new_stmt_scope(disable_sync=False):
        mul_li_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="mul_li_ub")
        mul_lij_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="mul_lij_ub")
        if self.high_precision:
            # The exponentials are FP16; cast them to the l dtype before multiplying.
            self.tik_instance.h_cast(mul_li_ub, mi_old_ub, "none")
            self.tik_instance.h_cast(mul_lij_ub, mij_ub, "none")
            self.tik_instance.h_mul(mul_li_ub, mul_li_ub, li_old_ub)
            self.tik_instance.h_mul(mul_lij_ub, mul_lij_ub, lij_ub)
        else:
            self.tik_instance.h_mul(mul_li_ub, mi_old_ub, li_old_ub)
            self.tik_instance.h_mul(mul_lij_ub, mij_ub, lij_ub)
        self.tik_instance.h_add(li_new_ub, mul_li_ub, mul_lij_ub)
    return mi_new_ub, li_new_ub
|
|
155
|
-
|
|
156
|
-
def update_o_m_l_fp32(self,
                      Pij_ub_fp32,
                      Vj_l1_K1NK0_ed,
                      Pij_ub,
                      mij_ub,
                      lij_ub,
                      batch_start,
                      batch_idx,
                      kv_blk_idx,
                      kv_blk_height,
                      q_blk_idx,
                      block_h):
    """Load o, m, l from GM, update them in UB, then write them back to GM (FP32 accumulation).

    Refer to Algorithm 2 lines 13 and 15. O is accumulated in ``self.O_gm_workspace`` (FP32).

    If ``kv_blk_idx == 0`` or ``kv_blk_idx + self.prev_block_num == q_blk_idx`` (presumably
    the first K/V block that contributes to this Q block, given the range filter in
    compute_in_each_kv_block), m and l are stored as-is and ``O = (Pij / lij) @ Vj``
    is written. Otherwise the stored m, l and O are read back and::

        mi_new = max(mi, mij)
        li_new = exp(mi - mi_new) * li + exp(mij - mi_new) * lij
        Oi_new = Oi * (li * exp(mi - mi_new) / li_new) + (exp(mij - mi_new) / li_new) * Pij @ Vj

    :param Pij_ub_fp32: FP32 exp-scores in UB, indexed as ``[idx, :, :]`` for
        ``idx < up_align(kv_blk_height) // self.N0``; scaled in place.
    :param Vj_l1_K1NK0_ed: current V block in L1; right operand of the Pij*Vj matmul.
    :param Pij_ub: FP16 UB buffer that receives the cast of the scaled ``Pij_ub_fp32``.
    :param mij_ub: input tensor with shape of (Br); row max of the current block.
        Overwritten by update_m_l on the update branch.
    :param lij_ub: input tensor with shape of (Br); row sum of the current block.
    :param batch_start: batch offset; ``batch_start + batch_idx`` selects the batch.
    :param batch_idx: batch index relative to ``batch_start``.
    :param kv_blk_idx: index of the current K/V block.
    :param kv_blk_height: number of rows in the current K/V block.
    :param q_blk_idx: index of the current Q block.
    :param block_h: number of rows in the current Q block.
    :return: None
    """
    vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
    o_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
    block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
    block_k_aligned_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
    n1 = block_k_aligned_aligned // self.N0
    with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
        li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
        # Expand 1/lij to a (block_h, N0) tile, one scalar per row.
        vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="vec_ub", scope=UB)
        for i in range(block_h):
            src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
            self.tik_instance.h_duplicate(vec_ub[i, :], src_scalar)
        vec_ub = vec_ub.reshape((1, block_h, self.N0))
        # Pij /= lij (row-wise), before the matmul.
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
                                    Pij_ub_fp32[idx, :, :],
                                    vec_ub)
        self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
        Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
            FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
        )
        self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
                                 burst=block_k_aligned_aligned * block_h_aligned // 16)
        Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
                                                                 kv_blk_height, self.actual_d, N1MN0_to_MN=False,
                                                                 precision_type=self.precision_type)  # Pij*Vj
        # FP32 rows of N0 values: N0 // 8 burst units per row (presumably 32-byte units).
        # NOTE(review): burst uses block_h but dst_stride uses block_h_aligned; the rows of
        # successive N1 slices look misplaced when block_h % 16 != 0 -- confirm.
        self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
                                    nburst=self.N1, burst=block_h * self.N0 // 8,
                                    src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)
    with self.tik_instance.else_scope():
        mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
        li_ub = self.tik_instance.Tensor(FP32, (block_h_aligned,), name="li_ub", scope=UB)
        self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, vec_gm_offset, block_h)
        # After this call mi_ub holds exp(mi - mi_new) and mij_ub holds exp(mij - mi_new).
        mi_new_ub, li_new_ub = self.update_m_l(mi_ub, mij_ub, li_ub, lij_ub, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, li_new_ub, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mi_new_ub, vec_gm_offset, block_h)

        exp_m_old_fp32 = self.tik_instance.Tensor(FP32, (block_h_aligned,), scope=UB, name="exp_m_old_fp32")
        exp_m_cur_fp32 = self.tik_instance.Tensor(FP32, (block_h_aligned,), scope=UB, name="exp_m_cur_fp32")
        self.tik_instance.h_cast(exp_m_old_fp32, mi_ub, "none")
        self.tik_instance.h_cast(exp_m_cur_fp32, mij_ub, "none")

        li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
        # exp(mij - mi_new) / li_new: the scale applied to Pij before the matmul.
        self.tik_instance.h_mul(exp_m_cur_fp32, exp_m_cur_fp32, li_new_rec_ub)
        exp_m_cur_fp32_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="exp_m_cur_fp32_vec_ub",
                                                         scope=UB)
        for i in range(block_h):
            src_scalar = self.tik_instance.Scalar(init_value=exp_m_cur_fp32[i], dtype=FP32)
            self.tik_instance.h_duplicate(exp_m_cur_fp32_vec_ub[i, :], src_scalar)
        exp_m_cur_fp32_vec_ub = exp_m_cur_fp32_vec_ub.reshape((1, block_h, self.N0))
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
                                    Pij_ub_fp32[idx, :, :],
                                    exp_m_cur_fp32_vec_ub)
        self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
        Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
            FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
        )
        self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
                                 burst=block_k_aligned_aligned * block_h_aligned // 16)
        Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
                                                                 kv_blk_height, self.actual_d, N1MN0_to_MN=False,
                                                                 precision_type=self.precision_type)  # Pij*Vj
        # NOTE: this rebinds n1 to the matmul result's leading dimension; the loop below uses it.
        n1, m, n0 = Pij_Vj_matmul_res_ub.shape
        Oi_ub = self.tik_instance.Tensor(FP32, (n1, m, n0), name="Oi_ub", scope=UB)
        self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm_workspace[o_gm_offset],
                                    sid=0, nburst=self.N1, burst=m * self.N0 // 8,
                                    src_stride=(self.Nq - m) * self.N0 // 8, dst_stride=0)

        # li * exp(mi - mi_new) / li_new: the scale applied to the previous O.
        self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, li_ub)
        self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, exp_m_old_fp32)
        li_new_rec_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="li_new_rec_vec_ub",
                                                     scope=UB)
        for i in range(block_h):
            src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
            self.tik_instance.h_duplicate(li_new_rec_vec_ub[i, :], src_scalar)
        li_new_rec_vec_ub = li_new_rec_vec_ub.reshape((1, block_h, self.N0))
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Oi_ub[idx, :, :],
                                    Oi_ub[idx, :, :],
                                    li_new_rec_vec_ub)
        self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
        # NOTE(review): the load above uses m for burst/stride, this store uses
        # block_h / block_h_aligned; these agree only if m == block_h == block_h_aligned -- confirm.
        self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Oi_ub, sid=0,
                                    nburst=self.N1, burst=block_h * self.N0 // 8,
                                    src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)
|
|
266
|
-
|
|
267
|
-
def exp_Pij_Vj(self, exp_mij_sub_mi_new, Pij_Vj_ub, block_h_aligned):
    """Scale ``Pij_Vj_ub`` by ``exp(mij - mi_new)`` (Algorithm 2, line 15).

    Computes exp(mij - mi_new) * Pij * Vj. The factor vector ``exp_mij_sub_mi_new`` is
    broadcast to ``(block_h_aligned, self.d)`` (presumably one factor per row) inside a
    temporary statement scope. It is then multiplied into ``Pij_Vj_ub`` in place, and the
    same tensor object is returned.
    """
    tik_inst = self.tik_instance
    target_shape = (block_h_aligned, self.d)
    with tik_inst.new_stmt_scope(disable_sync=False):
        scale = self.tik_ops_utils.broadcast(exp_mij_sub_mi_new, target_shape)
        tik_inst.h_mul(Pij_Vj_ub, Pij_Vj_ub, scale)
    return Pij_Vj_ub
|
|
275
|
-
|
|
276
|
-
def update_o_m_l(self,
                 Pij_l1_K1MK0_ed,
                 Vj_l1_K1NK0_ed,
                 mij_ub,
                 lij_ub,
                 batch_start,
                 batch_idx,
                 kv_blk_idx,
                 kv_blk_height,
                 q_blk_idx,
                 block_h):
    """Refer to Algorithm 2 line13 and line15 in FlashAttention

    FP16 counterpart of update_o_m_l_fp32: O is read from and written to ``self.O_gm``.
    Pij*Vj is computed first. If ``kv_blk_idx == 0`` or
    ``kv_blk_idx + self.prev_block_num == q_blk_idx``, m and l are stored as-is and
    ``O = (Pij @ Vj) / lij`` is written. Otherwise::

        Oi_new = Oi * (li * exp(mi - mi_new) / li_new) + (Pij @ Vj) * exp(mij - mi_new) / li_new

    :param Pij_l1_K1MK0_ed: exp-scores of the current block in L1; left matmul operand.
    :param Vj_l1_K1NK0_ed: current V block in L1; right matmul operand.
    :param mij_ub: row max of the current block; overwritten by update_m_l on the update branch.
    :param lij_ub: row sum of the current block.
    :param batch_start: batch offset; ``batch_start + batch_idx`` selects the batch.
    :param batch_idx: batch index relative to ``batch_start``.
    :param kv_blk_idx: index of the current K/V block.
    :param kv_blk_height: number of rows in the current K/V block.
    :param q_blk_idx: index of the current Q block.
    :param block_h: number of rows in the current Q block.
    :return: None
    """
    vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
    o_gm_offset = self.get_gm_offset(
        batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx
    )
    block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)

    Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
                                                             kv_blk_height, self.actual_d,
                                                             N1MN0_to_MN=False)  # Pij*Vj
    n1, m, n0 = Pij_Vj_matmul_res_ub.shape
    with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
        li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
        broadcast_li_new_rec_ub = self.tik_ops_utils.broadcast(li_new_rec_ub, (m, n0))
        broadcast_li_new_rec_ub = broadcast_li_new_rec_ub.reshape((1, m, n0))
        # (Pij @ Vj) /= lij, one N0-wide slice at a time.
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
                                    Pij_Vj_matmul_res_ub[idx, :, :],
                                    broadcast_li_new_rec_ub)
        # FP16 rows of N0 values: N0 // 16 burst units per row (presumably 32-byte units).
        # NOTE(review): burst uses block_h but dst_stride uses block_h_aligned; the rows of
        # successive N1 slices look misplaced when block_h % 16 != 0 -- confirm.
        self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
                                    nburst=self.N1, burst=block_h * self.N0 // 16,
                                    src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)

    with self.tik_instance.else_scope():
        mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
        li_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="li_ub", scope=UB)
        self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, vec_gm_offset, block_h)
        mi_new_ub, li_new_ub = self.update_m_l(mi_ub, mij_ub, li_ub, lij_ub, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, li_new_ub, vec_gm_offset, block_h)
        self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mi_new_ub, vec_gm_offset, block_h)
        # update_m_l overwrote these buffers in place with the exponentials.
        exp_mi_sub_mi_new = mi_ub
        exp_mij_sub_mi_new = mij_ub

        li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
        # scale1 = li * exp(mi - mi_new) / li_new  (applied to the previous O)
        self.tik_instance.h_mul(li_ub, li_ub, exp_mi_sub_mi_new)
        self.tik_instance.h_mul(li_ub, li_ub, li_new_rec_ub)
        scale1 = li_ub
        # scale2 = exp(mij - mi_new) / li_new  (applied to Pij @ Vj)
        self.tik_instance.h_mul(exp_mij_sub_mi_new, exp_mij_sub_mi_new, li_new_rec_ub)
        scale2 = exp_mij_sub_mi_new
        Oi_ub = self.tik_instance.Tensor(FP16, (n1, m, n0), name="Oi_ub", scope=UB)
        self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm[o_gm_offset],
                                    sid=0, nburst=self.N1, burst=m * self.N0 // 16,
                                    src_stride=(self.Nq - m) * self.N0 // 16, dst_stride=0)
        broadcast_scale1 = self.tik_ops_utils.broadcast(scale1, (m, n0))
        broadcast_scale1 = broadcast_scale1.reshape((1, m, n0))
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Oi_ub[idx, :, :], Oi_ub[idx, :, :], broadcast_scale1)
        broadcast_scale2 = self.tik_ops_utils.broadcast(scale2, (m, n0))
        broadcast_scale2 = broadcast_scale2.reshape((1, m, n0))
        with self.tik_instance.for_range(0, n1) as idx:
            self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
                                    Pij_Vj_matmul_res_ub[idx, :, :],
                                    broadcast_scale2)
        self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
        # NOTE(review): the load above uses m for burst/stride, this store uses
        # block_h / block_h_aligned; these agree only if m == block_h == block_h_aligned -- confirm.
        self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Oi_ub, sid=0,
                                    nburst=self.N1, burst=block_h * self.N0 // 16,
                                    src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)
|
|
347
|
-
|
|
348
|
-
def compute_in_each_kv_block(self, batch_start, batch_idx, kv_blk_idx, kv_blk_height,
                             core_idx_to_tr_info, core_idx):
    """The forward computation in each outer loop

    Loads the K/V block ``kv_blk_idx`` of batch ``batch_start + batch_idx`` into L1 once.
    It then iterates over the Q-block range assigned to ``core_idx``, read from
    ``core_idx_to_tr_info[core_idx, batch_start + batch_idx, 0]`` (start) and
    ``[..., 1]`` (end). For every Q block inside the band
    ``kv_blk_idx - next_block_num <= q_blk_idx <= kv_blk_idx + prev_block_num`` it calls
    compute_in_each_q_block. The Q block ``self.Tr - 1`` uses ``self.last_Br`` rows;
    all others use ``self.Br``.

    :param batch_start: batch offset; ``batch_start + batch_idx`` selects the batch.
    :param batch_idx: batch index relative to ``batch_start``.
    :param kv_blk_idx: index of the K/V block to process.
    :param kv_blk_height: number of rows in this K/V block.
    :param core_idx_to_tr_info: per-core, per-batch [start, end) Q-block ranges.
    :param core_idx: index of the current core.
    """
    kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
    kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d, self.Bc, kv_blk_idx)
    # load Kj (kv_blk_idx_th block of K_gm)
    KjT_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
                                               name="KjT_l1_K1MK0_ed", scope=L1)
    self.tik_instance.data_move(dst=KjT_l1_K1MK0_ed, src=self.K_gm[kv_gm_offset],
                                sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
                                src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)

    # load Vj (kv_blk_idx_th block of V_gm), then reorder for Pij*Vj
    Vj_l1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_l1", scope=L1)
    with self.tik_instance.new_stmt_scope(disable_sync=False):
        Vj_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
                                         name="Vj_ub", scope=UB)
        self.tik_instance.data_move(dst=Vj_ub, src=self.V_gm[kv_gm_offset],
                                    sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
                                    src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
        # (N1, K, N0) -> (K, N)
        Vj_ub = self.tik_ops_utils.N1MN0_TO_MN(Vj_ub)
        # (K, N) -> (K1, N, K0); Vj_l1 (allocated outside the scope) is used as the workspace.
        Vj_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Vj_ub, workspace_tensor=Vj_l1)

    tr_start_s = self.tik_instance.Scalar("int32", name="tr_start_s")
    tr_end_s = self.tik_instance.Scalar("int32", name="tr_end_s")
    tr_start_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 0])
    tr_end_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 1])
    with self.tik_instance.for_range(tr_start_s, tr_end_s, name="q_blk_idx") as q_blk_idx:
        # Skip invalid computation by exploiting the triangular structure of atten_mask:
        # only Q blocks within [kv_blk_idx - next_block_num, kv_blk_idx + prev_block_num]
        # are computed for this K/V block.
        with self.tik_instance.if_scope(tik.all(kv_blk_idx - self.next_block_num <= q_blk_idx,
                                                q_blk_idx <= kv_blk_idx + self.prev_block_num)):
            with self.tik_instance.if_scope(q_blk_idx != self.Tr - 1):
                self.compute_in_each_q_block(KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx,
                                             batch_start,
                                             kv_blk_height, self.Br, q_blk_idx, kv_blk_idx)
            with self.tik_instance.else_scope():
                # The last Q block may be shorter: use last_Br rows.
                self.compute_in_each_q_block(KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx,
                                             batch_start,
                                             kv_blk_height, self.last_Br, q_blk_idx, kv_blk_idx)
|
|
389
|
-
|
|
390
|
-
def compute_in_each_q_block(self, KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx, batch_start,
                            kv_blk_height, q_blk_height, q_blk_idx, kv_blk_idx):
    """Run the forward pass for one (Q block, KV block) pair.

    Instructions are emitted in this order:
      1. copy the Q block from ``self.Q_gm`` into L1;
      2. S = Qi * KjT via ``tik_ops_utils.matmul_compute`` (``N1MN0_to_MN=False``);
      3. optionally apply the ALiBi mask, then the attention mask, to S;
      4. ``softmax_compute`` on S, yielding P and the updated row buffers (m, l);
      5. optionally apply the dropout mask to P;
      6. fold P and Vj into the running result with ``update_o_m_l_fp32``
         (high precision) or, after staging P in L1 as FP16, ``update_o_m_l``.

    Args:
        KjT_l1_K1MK0_ed: K block (transposed) already resident in L1.
        Vj_l1_K1NK0_ed: V block already resident in L1.
        batch_idx: batch index relative to ``batch_start``.
        batch_start: first batch handled by the calling core.
        kv_blk_height: actual row count of the KV block.
        q_blk_height: actual row count of the Q block.
        q_blk_idx: index of the Q block (used for GM offsets).
        kv_blk_idx: index of the KV block (used for GM offsets).
    """
    tik_inst = self.tik_instance
    utils = self.tik_ops_utils
    kv_h_pad = utils.up_align_to_K0(kv_blk_height)
    q_h_pad = utils.up_align_to_K0(q_blk_height)

    # Stage Qi (the q_blk_idx-th block of Q_gm) in L1 for the Qi * KjT matmul.
    q_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
    q_l1 = tik_inst.Tensor(FP16, (self.d // self.N0, q_h_pad, self.N0),
                           scope=L1, name="Qi_l1_K1MK0_ed")
    tik_inst.data_move(dst=q_l1, src=self.Q_gm[q_offset], sid=0, nburst=self.N1,
                       burst=q_h_pad * self.N0 // 16,
                       src_stride=(self.Nq - q_h_pad) * self.N0 // 16, dst_stride=0)

    # Row-statistic buffers handed to softmax_compute (l uses precision_type, m uses FP16).
    row_sum_ub = tik_inst.Tensor(self.precision_type, (q_h_pad,), scope=UB, name="lij_ub")
    row_max_ub = tik_inst.Tensor(FP16, (q_h_pad,), scope=UB, name="mij_ub")

    # S = Qi * KjT, left in the N1MN0 layout.
    scores = utils.matmul_compute(q_l1, KjT_l1_K1MK0_ed, m=q_blk_height, k=self.actual_d,
                                  n=kv_blk_height, N1MN0_to_MN=False)

    if self.has_alibi_mask:
        alibi_offset = self.get_alibi_gm_offset(batch_start, batch_idx, self.N, self.Bc, kv_blk_idx)
        self.do_alibi_mask(scores, alibi_offset, q_h_pad, kv_h_pad)

    if self.has_attn_mask:
        attn_offset = self.get_attn_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
                                                   self.Br, q_blk_idx, self.Bc, kv_blk_idx)
        self.do_att_mask(scores, attn_offset, q_blk_height, kv_blk_height, q_h_pad, kv_h_pad)

    # The original author notes that with high_precision=True the returned P is FP32.
    probs, row_max_ub, row_sum_ub = self.softmax_compute(scores, row_max_ub, row_sum_ub,
                                                         q_blk_height, kv_blk_height)

    if self.has_drop_mask:
        drop_offset = self.get_drop_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
                                                   self.Br, q_blk_idx, self.Bc, kv_blk_idx)
        self.do_dropout_mask(probs, drop_offset, kv_h_pad, kv_blk_height, q_h_pad, q_blk_height,
                             precision_type=self.precision_type)

    # Positional arguments shared by both update routines.
    block_position = (batch_start, batch_idx, kv_blk_idx, kv_blk_height, q_blk_idx, q_blk_height)
    if self.high_precision:
        self.update_o_m_l_fp32(probs, Vj_l1_K1NK0_ed, scores, row_max_ub, row_sum_ub,
                               *block_position)
    else:
        # Low-precision path: copy P into an FP16 L1 tensor before it meets Vj.
        probs_l1 = tik_inst.Tensor(FP16, (kv_h_pad // self.N0, q_h_pad, self.N0),
                                   name="Pij_l1_K1MK0_ed", scope=L1)
        self.cont_data_mv_1_bust(dst=probs_l1, src=probs, burst=kv_h_pad * q_h_pad // 16)
        self.update_o_m_l(probs_l1, Vj_l1_K1NK0_ed, row_max_ub, row_sum_ub, *block_position)
|
|
463
|
-
|
|
464
|
-
def compute_one_core(self, batch_start_sc, batch_num_sc, core_idx_to_tr_info, core_idx):
    """The computation of FlashAttention forward on each core.

    For every batch assigned to this core, iterate over all KV blocks (the last
    block uses ``self.last_Bc`` rows, the others ``self.Bc``) and delegate to
    ``compute_in_each_kv_block``. When ``self.high_precision`` is set, the batch's
    result is then copied from ``self.O_gm_workspace`` (FP32) through UB, cast to
    FP16 with ``h_cast`` and written to ``self.O_gm``, at most ``block_h`` rows at
    a time.

    Args:
        batch_start_sc: index of the first batch handled by this core.
        batch_num_sc: number of batches handled by this core.
        core_idx_to_tr_info: per-core Q-block range table, forwarded to
            ``compute_in_each_kv_block``.
        core_idx: index of the current core.

    Raises:
        ValueError: if ``DTYPE_SIZE`` has no entry for FP32 or FP16.
    """
    with self.tik_instance.for_range(0, batch_num_sc, name="batch_index") as batch_idx:
        with self.tik_instance.for_range(0, self.Tc, name="kv_blk_idx") as kv_blk_idx:
            with self.tik_instance.if_scope(kv_blk_idx != self.Tc - 1):
                self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.Bc,
                                              core_idx_to_tr_info, core_idx)
            with self.tik_instance.else_scope():
                self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.last_Bc,
                                              core_idx_to_tr_info, core_idx)
        if self.high_precision:
            block_h = 128
            # Base GM offset of this batch's (Nq, d) output; never mutated below.
            batch_gm_offset = (batch_start_sc + batch_idx) * (self.Nq * self.d)
            # UB staging buffers hold at most block_h rows, so no single copy may exceed that.
            temp_ub = self.tik_instance.Tensor(FP32, (block_h, self.d), name="temp_ub", scope=UB)
            temp_ub_fp16 = self.tik_instance.Tensor(FP16, (block_h, self.d), name="temp_ub_fp16", scope=UB)
            try:
                dtype32_size = DTYPE_SIZE[FP32]
            except KeyError as err:
                raise ValueError("The argument 'FP32' is not valid.") from err
            try:
                dtype16_size = DTYPE_SIZE[FP16]
            except KeyError as err:
                raise ValueError("The argument 'FP16' is not valid.") from err

            def cast_rows_to_fp16(gm_offset, rows):
                """Copy `rows` (<= block_h) FP32 rows from the workspace, cast, store as FP16."""
                # Burst lengths are in 32-byte units; assumes rows * d * dtype_size is a
                # multiple of 32 — TODO confirm d is always a multiple of 16.
                self.cont_data_mv_1_bust(dst=temp_ub, src=self.O_gm_workspace[gm_offset],
                                         burst=rows * self.d * dtype32_size // 32)
                self.tik_instance.h_cast(temp_ub_fp16, temp_ub, "none")
                self.cont_data_mv_1_bust(dst=self.O_gm[gm_offset], src=temp_ub_fp16,
                                         burst=rows * self.d * dtype16_size // 32)

            pair_h = block_h * 2
            pair_num = self.Nq // pair_h
            if pair_num > 0:
                # Full blocks are processed two at a time with thread_num=2 (double buffering).
                with self.tik_instance.for_range(0, pair_num) as i:
                    with self.tik_instance.for_range(0, 2, thread_num=2) as t_idx:
                        index = i * 2 + t_idx
                        cast_rows_to_fp16(batch_gm_offset + index * (block_h * self.d), block_h)
            # Tail rows not covered by full pairs. The remainder can be up to
            # 2 * block_h - 1 rows, so split it into chunks that fit the UB buffers
            # instead of overflowing temp_ub with a single oversized copy.
            row_start = pair_num * pair_h
            while row_start < self.Nq:
                rows = min(block_h, self.Nq - row_start)
                cast_rows_to_fp16(batch_gm_offset + row_start * self.d, rows)
                row_start += rows
|
|
506
|
-
|
|
507
|
-
def collect_inputs(self):
    """Return the input GM tensors in the parameter order of the Primitive and ``__init__``.

    Q, K and V always come first. They are followed, in this order, by the
    attention mask, the dropout mask and the ALiBi mask, each included only when
    its corresponding ``has_*`` flag is truthy.
    """
    # (flag attribute, tensor attribute) pairs, in the order they must appear.
    optional_inputs = (
        ("has_attn_mask", "att_mask_gm"),
        ("has_drop_mask", "drop_mask_gm"),
        ("has_alibi_mask", "alibi_mask_gm"),
    )
    mandatory = [self.Q_gm, self.K_gm, self.V_gm]
    # The tensor attribute is only read when its flag is set, as in the original.
    enabled = [
        getattr(self, tensor_attr)
        for flag_attr, tensor_attr in optional_inputs
        if getattr(self, flag_attr)
    ]
    return mandatory + enabled
|
|
520
|
-
|
|
521
|
-
def collect_outputs(self):
    """Return the output GM tensors in the parameter order of the Primitive and ``__init__``.

    The order is ``O_gm``, ``l_gm``, ``m_gm``, which lines up with the
    ``output, rowsum, rowmax`` parameters of ``flash_attention``.
    """
    outputs = (self.O_gm, self.l_gm, self.m_gm)
    return list(outputs)
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
def flash_attention(query, key, value, attn_mask, dropout_mask, alibi_mask, output, rowsum, rowmax,
                    prev_block_num=65536, next_block_num=65536, high_precision=False, tiling_stgy_name='sparse',
                    kernel_name="flash_attention", disable_debug=True):
    """
    algorithm: flash_attention (forward pass)

    Builds a ``FlashAttentionFwd`` kernel, runs its ``compute_process`` and returns
    the resulting TIK instance.

    Parameters
    ----------
    query : dict. shape and dtype of input, only support float16
    key : dict. shape and dtype of input, only support float16
    value: dict. shape and dtype of input, only support float16
    attn_mask: dict. shape and dtype of input, only support float16
    dropout_mask: dict. shape and dtype of input, only support float16
    alibi_mask: dict. shape and dtype of input, only support float16
    output: dict. shape and dtype of output, only support float16
    rowsum: dict. shape and dtype of output, only support float16
    rowmax: dict. shape and dtype of output, only support float16
        NOTE: ``output``, ``rowsum`` and ``rowmax`` are accepted as part of the
        operator signature but are not forwarded to ``FlashAttentionFwd`` here.
    prev_block_num: int. an attribute used to define sparse attention
        (size of the block window on one side of the diagonal)
    next_block_num: int. an attribute used to define sparse attention
        (size of the block window on the other side of the diagonal)
    high_precision: bool. forwarded to ``FlashAttentionFwd``; presumably selects an
        FP32 accumulation path for the output — TODO confirm. Default False.
    tiling_stgy_name: str. an attribute used to choose the tiling strategy,
        resolved via ``TilingStrategy.from_strategy_name``. Default 'sparse'.
    kernel_name: str. cce kernel name, default value is "flash_attention"
    disable_debug: bool. whether disable debug

    Returns
    -------
    tik_instance
    """
    fa = FlashAttentionFwd(query=query, key=key, value=value, attn_mask=attn_mask,
                           dropout_mask=dropout_mask, alibi_mask=alibi_mask, kernel_name=kernel_name,
                           tiling_stgy=TilingStrategy.from_strategy_name(tiling_stgy_name),
                           prev_block_num=prev_block_num, next_block_num=next_block_num,
                           high_precision=high_precision, disable_debug=disable_debug)
    fa.compute_process()
    return fa.tik_instance
|