mindspore-2.2.10-cp39-cp39-win_amd64.whl → mindspore-2.2.14-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +2 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +12 -2
- mindspore/common/_utils.py +16 -0
- mindspore/common/tensor.py +0 -2
- mindspore/communication/management.py +3 -0
- mindspore/context.py +34 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets.py +23 -0
- mindspore/dataset/engine/validators.py +1 -1
- mindspore/dataset/vision/py_transforms_util.py +2 -2
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/optim/lr_scheduler.py +5 -6
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
- mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
- mindspore/mindrecord/tools/csv_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/activation.py +1 -1
- mindspore/nn/layer/embedding.py +2 -2
- mindspore/nn/layer/flash_attention.py +48 -135
- mindspore/nn/loss/loss.py +1 -1
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/sgd.py +3 -2
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +6 -3
- mindspore/numpy/math_ops.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/function/array_func.py +6 -5
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/linalg_func.py +21 -11
- mindspore/ops/function/math_func.py +3 -0
- mindspore/ops/function/nn_func.py +13 -11
- mindspore/ops/function/parameter_func.py +2 -0
- mindspore/ops/function/sparse_unary_func.py +2 -2
- mindspore/ops/function/vmap_func.py +1 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +3 -4
- mindspore/ops/operations/_inner_ops.py +56 -1
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +15 -4
- mindspore/ops/operations/custom_ops.py +1 -1
- mindspore/ops/operations/debug_ops.py +1 -1
- mindspore/ops/operations/image_ops.py +3 -3
- mindspore/ops/operations/inner_ops.py +49 -0
- mindspore/ops/operations/math_ops.py +65 -3
- mindspore/ops/operations/nn_ops.py +95 -28
- mindspore/ops/operations/random_ops.py +2 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/silent_check.py +162 -0
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +82 -3
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_tensor.py +3 -1
- mindspore/parallel/_transformer/transformer.py +8 -8
- mindspore/parallel/checkpoint_transform.py +191 -45
- mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
- mindspore/profiler/parser/ascend_flops_generator.py +8 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
- mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
- mindspore/profiler/parser/ascend_op_generator.py +15 -7
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
- mindspore/profiler/parser/base_timeline_generator.py +11 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +8 -2
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
- mindspore/profiler/parser/msadvisor_parser.py +9 -3
- mindspore/profiler/profiling.py +97 -25
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/api/symbol_tree.py +2 -2
- mindspore/rewrite/parsers/for_parser.py +6 -6
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/callback/_checkpoint.py +8 -8
- mindspore/train/callback/_landscape.py +2 -3
- mindspore/train/callback/_summary_collector.py +6 -7
- mindspore/train/dataset_helper.py +6 -0
- mindspore/train/model.py +17 -5
- mindspore/train/serialization.py +6 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +5 -6
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +117 -124
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
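The version bump itself can be confirmed from Python after installation. A minimal sketch, assuming the 2.2.14 wheel has been installed into the active environment with a standard pip install:

# Post-upgrade sanity check; assumes the 2.2.14 wheel is installed.
import mindspore as ms

print(ms.__version__)   # expected to print 2.2.14 after the upgrade
ms.run_check()          # MindSpore's built-in installation self-check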
mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py (deleted)
@@ -1,193 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""The impl of flash attention"""
-from __future__ import absolute_import
-import mindspore.ops as ops
-import mindspore.common.dtype as mstype
-from mindspore.ops import Custom
-from mindspore.ops import DataType
-from mindspore.ops import TBERegOp
-from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_bwd import flash_attention_grad
-from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_fwd import flash_attention
-from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
-
-KERNEL_NAME = "flash_attention"
-
-cus_flash_atten_op_info = TBERegOp("FlashAttentionPrimitive") \
-    .fusion_type("OPAQUE") \
-    .partial_flag(True) \
-    .async_flag(False) \
-    .binfile_name("flash_attention.so") \
-    .compute_cost(10) \
-    .kernel_name(KERNEL_NAME) \
-    .attr("prev_block_num", "required", "int", "all", "65536") \
-    .attr("next_block_num", "required", "int", "all", "65536") \
-    .attr("high_precision", "required", "bool", "all", "false") \
-    .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
-    .input(0, "query", False, "required", "all") \
-    .input(1, "key", False, "required", "all") \
-    .input(2, "value", False, "required", "all") \
-    .input(3, "attn_mask", False, "optional", "all") \
-    .input(4, "dropout_mask", False, "optional", "all") \
-    .input(5, "alibi_mask", False, "optional", "all") \
-    .output(0, "output", False, "required", "all") \
-    .output(1, "rowsum", False, "required", "all") \
-    .output(2, "rowmax", False, "required", "all") \
-    .dtype_format(DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_Default) \
-    .dtype_format(DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F32_Default,
-                  DataType.F16_Default) \
-    .get_op_info()
-
-GRAD_KERNEL_NAME = "flash_attention_grad"
-
-cus_flash_atten_grad_op_info = TBERegOp("FlashAttentionGradPrimitive") \
-    .fusion_type("OPAQUE") \
-    .partial_flag(True) \
-    .async_flag(False) \
-    .binfile_name("flash_attention_grad.so") \
-    .compute_cost(10) \
-    .kernel_name(GRAD_KERNEL_NAME) \
-    .attr("prev_block_num", "required", "int", "all", "65536") \
-    .attr("next_block_num", "required", "int", "all", "65536") \
-    .attr("high_precision", "required", "bool", "all", "false") \
-    .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
-    .input(0, "query", False, "required", "all") \
-    .input(1, "key", False, "required", "all") \
-    .input(2, "value", False, "required", "all") \
-    .input(3, "output", False, "required", "all") \
-    .input(4, "do", False, "required", "all") \
-    .input(5, "rowsum", False, "required", "all") \
-    .input(6, "rowmax", False, "required", "all") \
-    .input(7, "attn_mask", False, "optional", "all") \
-    .input(8, "dropout_mask", False, "optional", "all") \
-    .input(9, "alibi_mask", False, "optional", "all") \
-    .output(0, "dq", False, "required", "all") \
-    .output(1, "dk", False, "required", "all") \
-    .output(2, "dv", False, "required", "all") \
-    .dtype_format(DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F32_FracNZ,
-                  DataType.F32_FracNZ,
-                  DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F16_FracNZ,
-                  DataType.F32_Default,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F16_Default,
-                  DataType.F16_FracNZ,
-                  DataType.F32_FracNZ,
-                  DataType.F32_FracNZ,
-                  DataType.F32_FracNZ) \
-    .get_op_info()
-
-
-def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
-                             tiling_stgy_name='sparse', high_precision=False):
-    """get flash attention grad"""
-
-    def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
-                    att_mask_shape, dropout_mask_shape, alibi_mask_shape):
-        return q_shape, k_shape, v_shape
-
-    def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
-                    attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
-        return mstype.float32, mstype.float32, mstype.float32
-
-    fa_grad = Custom(flash_attention_grad, out_shape=infer_shape,
-                     out_dtype=infer_dtype, func_type="tbe", reg_info=cus_flash_atten_grad_op_info)
-    fa_grad.add_prim_attr("prev_block_num", prev_block_num)
-    fa_grad.add_prim_attr("next_block_num", next_block_num)
-    fa_grad.add_prim_attr("high_precision", high_precision)
-    fa_grad.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
-    fa_grad.init_prim_io_names(
-        inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "attn_mask", "dropout_mask",
-                "alibi_mask"],
-        outputs=["dq", "dk", "dv"]
-    )
-
-    def bprop(query, key, value, attn_mask, dropout_mask, alibi_mask, out, douts):
-        output, rowsum, rowmax = out
-        dout, _, _ = douts
-        dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, attn_mask, dropout_mask,
-                             alibi_mask)
-        dq = ops.cast(dq, mstype.float16)
-        dk = ops.cast(dk, mstype.float16)
-        dv = ops.cast(dv, mstype.float16)
-        return dq, dk, dv, zeros_like(attn_mask), \
-            zeros_like(dropout_mask), zeros_like(alibi_mask)
-
-    return bprop
-
-
-def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
-    """get_flash_attention"""
-
-    def infer_shape(q_shape, k_shape, v_shape, attn_mask_shape=None,
-                    dropout_mask_shape=None, alibi_mask_shape=None):
-        """infer shape"""
-        batch, hidden_size, seq_len, _ = q_shape
-        l_shape = (batch, hidden_size, seq_len)
-        m_shape = (batch, hidden_size, seq_len)
-        return q_shape, l_shape, m_shape
-
-    def infer_dtype(q_dtype, k_dtype, v_dtype, attn_mask_dtype=None,
-                    dropout_mask_dtype=None, alibi_mask_type=None):
-        """infer type"""
-        l_dtype = mstype.float16
-        if high_precision:
-            l_dtype = mstype.float32
-        return q_dtype, l_dtype, q_dtype
-
-    fa_grad = get_flash_attention_grad(prev_block_num, next_block_num, tiling_stgy_name, high_precision)
-    fa_forward = Custom(flash_attention, out_shape=infer_shape,
-                        out_dtype=infer_dtype, func_type="tbe", bprop=fa_grad,
-                        reg_info=cus_flash_atten_op_info)
-    fa_forward.add_prim_attr("prev_block_num", prev_block_num)
-    fa_forward.add_prim_attr("next_block_num", next_block_num)
-    fa_forward.add_prim_attr("high_precision", high_precision)
-    fa_forward.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
-    fa_forward.init_prim_io_names(
-        inputs=["query", "key", "value", "attn_mask", "dropout_mask", "alibi_mask"],
-        outputs=["output", "rowsum", "rowmax"]
-    )
-
-    return fa_forward
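The deleted file above built the old TBE-based custom flash-attention primitive that 2.2.14 removes. For orientation only, a hedged sketch of how that primitive was constructed under 2.2.10 follows; the final call is left as a comment because it required an Ascend device with the TBE toolchain, and this import path no longer exists after the upgrade.

# Illustrative sketch of the removed 2.2.10 code path (not available in 2.2.14).
from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention

# Returns an ops.Custom primitive with the bprop built by get_flash_attention_grad.
fa = get_flash_attention(prev_block_num=65536, next_block_num=65536,
                         tiling_stgy_name="sparse", high_precision=False)

# Per the registration info in the deleted file, the primitive consumed float16
# query/key/value plus optional attn_mask, dropout_mask and alibi_mask, and produced
# output, rowsum, rowmax, where rowsum/rowmax are the softmax statistics reused by
# the backward kernel:
# output, rowsum, rowmax = fa(query, key, value, attn_mask, dropout_mask, alibi_mask)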
mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py (deleted)
@@ -1,435 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""the common about tik ops"""
-from functools import partial
-
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
-
-
-class TikOpsUtils:
-    """Utils function class about tik ops"""
-
-    def __init__(self, tik_instance):
-        self.tik_instance = tik_instance
-        self.dtype = "float16"
-        self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
-                                           src_stride=0,
-                                           dst_stride=0)
-
-    def MK_TO_K1MK0(self, mk_input_tensor, workspace_tensor=None):
-        """change data shape from (M, K) to (K1, M, K0), K1 = K // K0, the effect is equant to:
-        new_tensor = np.stack(np.hsplit(mk_input_tensor, K1), axis=0)
-
-        :param mk_input_tensor: input tensor in GM with shape: (M, K)
-        :param workspace_tensor: workspace tensor with shape: (K1, M, K0)
-        tensor will be changed, otherwise the new data will be copied to the workspace tensor,
-        and input tensor will stay unchanged.
-        :return: Tensor with shape (K1,M, K0)
-        """
-        dtype = mk_input_tensor.dtype
-        m, k = mk_input_tensor.shape
-        K0 = 16
-        K1 = k // K0
-        M = self.up_align_to_K0(m)
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        if workspace_tensor is not None:
-            with self.tik_instance.for_range(0, K1) as i:
-                self.tik_instance.data_move(
-                    workspace_tensor[i * M * K0:],
-                    mk_input_tensor[i * K0:],
-                    0,
-                    M,
-                    K0 * dtype_size // 32,
-                    (K1 - 1) * K0 * dtype_size // 32,
-                    0,
-                )
-            return workspace_tensor.reshape((K1, M, K0))
-
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            tmp_ub = self.tik_instance.Tensor(dtype, (K1, M, K0), name="tmp_ub", scope=UB)
-            # data_move(m,k) --> (k1,m,K0)
-            with self.tik_instance.for_range(0, K1) as i:
-                self.tik_instance.data_move(
-                    tmp_ub[i * M * K0:],
-                    mk_input_tensor[i * K0:],
-                    0,
-                    M,
-                    K0 * dtype_size // 32,
-                    (K1 - 1) * K0 * dtype_size // 32,
-                    0,
-                )
-            self.cont_data_mv_1_bust(
-                dst=mk_input_tensor, src=tmp_ub, burst=K1 * M * K0 * dtype_size // 32)
-            return mk_input_tensor.reshape((K1, M, K0))
-
-    def transpose_matrix(self, src_ub, dst_ub, N, nk0=False):
-        """ transpose matrix, default support shape: (16, n) -> (n, 16)
-        if nk0 is true, support shape: (n, 16) -> (16, n)
-        """
-        K0 = 16
-        rep_times = N // K0
-        if nk0:
-            src_list = [src_ub[16 * i] for i in range(16)]
-            dst_list = [dst_ub[N * i] for i in range(16)]
-        else:
-            src_list = [src_ub[N * i] for i in range(16)]
-            dst_list = [dst_ub[16 * i] for i in range(16)]
-
-        dst_rep_stride = K0
-        src_rep_stride = 1
-        if rep_times == 1:
-            dst_rep_stride = 0
-            src_rep_stride = 0
-
-        if nk0:
-            src_rep_stride, dst_rep_stride = dst_rep_stride, src_rep_stride
-
-        self.tik_instance.vec_trans_scatter(
-            False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride
-        )
-        return dst_ub
-
-    def KN_TO_K1NK0(self, kn_input_tensor, workspace_tensor=None):
-        """change data shape from (K,N) to (K1, N, K0), K1 = K // K0, the effect is equvilent to:
-        new_tensor = np.reshape(kn_input_tensor, newshape=(K1, K0, N)).swapaxes(1, 2)
-
-        :param kn_input_tensor: input tensor with shape: (K, N)
-        :param workspace_tensor: workspace tensor with shape: (K1, N, K0)
-        tensor will be changed, otherwise the new data will be copied to the workspace tensor,
-        and input tensor will stay unchanged.
-        :return: Tensor with shape: (K1, N, K0)
-        """
-        dtype = kn_input_tensor.dtype
-        k, n = kn_input_tensor.shape
-        K0 = 16
-        K1 = k // K0
-        N = n
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        with self.tik_instance.for_range(0, K1) as index:
-            k1nk0_ub = self.tik_instance.Tensor(dtype, (N, K0), UB, "k1nk0_ub")
-            src_ub = self.tik_instance.Tensor(dtype, (K0, N), UB, "src_ub")
-            burst_len = K0 * N * dtype_size // 32
-            self.cont_data_mv_1_bust(dst=src_ub, src=kn_input_tensor[index * K0 * N],
-                                     burst=burst_len)
-            k1nk0_ub = self.transpose_matrix(src_ub, k1nk0_ub, N)
-            if workspace_tensor is None:
-                self.cont_data_mv_1_bust(dst=kn_input_tensor[index * K0 * N], src=k1nk0_ub,
-                                         burst=burst_len)
-            else:
-                self.cont_data_mv_1_bust(dst=workspace_tensor[index * K0 * N], src=k1nk0_ub,
-                                         burst=burst_len)
-        if workspace_tensor is None:
-            return kn_input_tensor.reshape((K1, N, K0))
-
-        return workspace_tensor.reshape((K1, N, K0))
-
-    def N1MN0_TO_MN(self, N1MN0_input):
-        """change data shape from (N1, M, N0) to (M, N), N0=16, N = N1 * K0, the effect is equant to:
-        N1MN0_input = np.concatenate(list(map(np.squeeze, np.split(N1MN0_input, N1))), axis=1)
-
-        :param N1MN0_input: input tensor with shape (N, M, N0) in GM or L1.
-        :return:
-        """
-        dtype = N1MN0_input.dtype
-        N1, M, N0 = N1MN0_input.shape
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            tmp_ub = self.tik_instance.Tensor(dtype, (M, N1 * N0), name="tmp_ub", scope=UB)
-            # data_move (n1,m,n0) --> (m,n)
-            with self.tik_instance.for_range(0, N1) as i:
-                self.tik_instance.data_move(
-                    tmp_ub[i * N0:],
-                    N1MN0_input[i * M * N0:],
-                    0,
-                    M,
-                    N0 * dtype_size // 32,
-                    0,
-                    (N1 - 1) * N0 * dtype_size // 32,
-                )
-            # data_move out
-            self.cont_data_mv_1_bust(dst=N1MN0_input, src=tmp_ub, burst=M * N1 * N0 * dtype_size // 32)
-        return N1MN0_input.reshape((M, N1 * N0))
-
-    def broadcast(self, vec_ub, shape):
-        """ broadcast a vector to a matrix
-        :param vec_ub: a tensor in UB with shape of (M,), and dtype is float16
-        :param shape: the target shape, a tuple with value (M, N), M and N are integer multiples of 16
-        :return: a tensor in UB with shape of (M, N)
-        """
-        M, N = shape
-        dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
-
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            # (M,) -> (2, M) -> (4, M) -> (8, M) -> (16, M)
-            tmp_ub1 = self.tik_instance.Tensor(FP16, (16, M), name="tmp_ub1", scope=UB)
-            self.tik_instance.data_move(tmp_ub1, vec_ub, 0, 1, M // 16, 0, 0)
-            times = self.tik_instance.Scalar("int32", name="times", init_value=1)
-            with self.tik_instance.for_range(begint=0, endt=16):
-                with self.tik_instance.if_scope(times <= 8):
-                    offset = times * M
-                    burst = times * M // 16
-                    self.cont_data_mv_1_bust(dst=tmp_ub1[offset], src=tmp_ub1, burst=burst)
-                with self.tik_instance.else_scope():
-                    self.tik_instance.tik_break()
-                times.set_as(times * 2)
-
-            # (16, M) -> (M, 16)
-            tmp_ub2 = self.tik_instance.Tensor(FP16, (M, 16), name="tmp_ub2", scope=UB)
-            tmp_ub2_transposed = self.transpose_matrix(tmp_ub1, tmp_ub2, M)
-
-            # (M, 16) -> (M, 32) -> (M, 64) -> ... -> (M, N)
-            self.tik_instance.data_move(dst_ub, tmp_ub2_transposed, 0, M, 1, 0, N // 16 - 1)
-            times.set_as(1)
-            with self.tik_instance.for_range(begint=0, endt=N):
-                offset = times * 16
-                with self.tik_instance.if_scope(offset * 2 <= N):
-                    burst = offset // 16
-                    src_stride = N // 16 - burst
-                    dst_stride = N // 16 - burst
-                    self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
-                                                dst_stride)
-                with self.tik_instance.else_scope():
-                    burst = (N - offset) // 16
-                    src_stride = N // 16 - burst
-                    dst_stride = N // 16 - burst
-                    with self.tik_instance.if_scope(burst > 0):
-                        self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
-                                                    dst_stride)
-                    self.tik_instance.tik_break()
-                times.set_as(times * 2)
-        return dst_ub
-
-    def broadcast_row(self, vec_ub, shape):
-        """broadcast row"""
-        M, N = shape
-        dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
-        self.tik_instance.data_move(dst_ub, vec_ub, 0, 1, N // 16, 0, 0)
-        times = self.tik_instance.Scalar("int32", name="times", init_value=1)
-        # (1, N) -> (2, M) -> (4, N) -> ... -> (M, N)
-        with self.tik_instance.for_range(begint=0, endt=M):
-            with self.tik_instance.if_scope(times * 2 <= M):
-                burst = times * N // 16
-                offset = times * N
-                self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
-            with self.tik_instance.else_scope():
-                burst = (M - times) * N // 16
-                offset = times * N
-                with self.tik_instance.if_scope(burst > 0):
-                    self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
-                self.tik_instance.tik_break()
-            times.set_as(times * 2)
-        return dst_ub
-
-    def get_K0(self, dtype=None):
-        """get K0"""
-        if dtype is None:
-            dtype = self.dtype
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        return 32 // dtype_size
-
-    def up_align_to_K0(self, n, dtype=None):
-        """byte alignment by dtype"""
-        if dtype is None:
-            dtype = self.dtype
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        K0 = 32 // dtype_size
-        return (n + K0 - 1) // K0 * K0
-
-    def calc_vec_rec(self, vec_ub, vec_len):
-        """cal the reciprocal of a vector"""
-        dtype = vec_ub.dtype
-        vec_len_aligned = self.up_align_to_K0(vec_len)
-        vec_rec_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="li_new_rec_ub")
-        try:
-            dtype_size = DTYPE_SIZE[dtype]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        mask_len = 256 // dtype_size
-        block_len = 32 // dtype_size
-        work_size = 8 // dtype_size
-
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            repeat_times = vec_len // mask_len
-            if repeat_times > 0:
-                dst_rep_stride = 8
-                src_rep_stride = 8
-
-                src_extent_size = (repeat_times - 1) * src_rep_stride * block_len + mask_len
-                wk_size_unit = ((src_extent_size + block_len - 1) // block_len) * block_len
-                wk_size = work_size * wk_size_unit
-                # define the work tensor
-                work_tensor_ub = self.tik_instance.Tensor(
-                    "float32", (wk_size,), name="work_tensor_ub", scope=UB
-                )
-                # if work_tensor is indexed, it must be written as work_tensor[index:]
-                self.tik_instance.vec_rec_high_preci(
-                    mask_len,
-                    vec_rec_ub[0:],
-                    vec_ub[0:],
-                    work_tensor_ub[0:],
-                    repeat_times,
-                    dst_rep_stride,
-                    src_rep_stride,
-                )
-
-            mask_len = vec_len - repeat_times * mask_len
-            if mask_len > 0:
-                wk_size = work_size * ((mask_len + block_len - 1) // block_len) * block_len
-                work_tensor_ub2 = self.tik_instance.Tensor(
-                    "float32", (wk_size,), name="work_tensor_ub2", scope=UB
-                )
-                self.tik_instance.vec_rec_high_preci(
-                    mask_len,
-                    vec_rec_ub[repeat_times * 128:],
-                    vec_ub[repeat_times * 128:],
-                    work_tensor_ub2[0:],
-                    1,
-                    0,
-                    0,
-                )
-        return vec_rec_ub
-
-    def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, right_all_one_matrix_l1, rowsum_ub, m, k, precision_type):
-        """compute the row sum of a matrix on the cube unit: right-multiply by an all-ones matrix of shape (n, 1)
-        :param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
-        :param right_all_one_matrix_l1: input tensor with shape (K, 16)
-        :param rowsum_ub: output tensor stores the row sum of input tensor
-        :param m: actual tensor height
-        :param k: actual tensor width
-        :return: row sum of the output tensor
-        """
-        K1, M, K0 = matrix_l1_K1MK0_ed.shape
-        # call matmul to compute the row sum; the result shape is (m, 16), take the first element of each row
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            row_sum_ub_N1MN0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
-                                                   N1MN0_to_MN=False, precision_type=precision_type)
-            row_sum_ub_MN_ed = row_sum_ub_N1MN0.reshape((M, 16))
-            if precision_type == FP32:
-                for idx in range(0, m):
-                    cur_row_sum = self.tik_instance.Scalar(FP32, init_value=row_sum_ub_MN_ed[idx, 0])
-                    rowsum_ub[idx].set_as(cur_row_sum)
-            else:
-                # transpose row_sum_ub_MN_ed first, then take one row; replaces the per-row operation: lij_ub[i].set_as(row_sum_ub_MN_ed[i, 0])
-                row_sum_ub_trans = self.tik_instance.Tensor(FP16, (16, M), name="row_sum_ub_trans", scope=UB)
-                row_sum_ub_trans = self.transpose_matrix(row_sum_ub_MN_ed, row_sum_ub_trans, M, True)
-                self.cont_data_mv_1_bust(dst=rowsum_ub, src=row_sum_ub_trans, burst=M // 16)
-
-        return rowsum_ub
-
-    def matmul_compute(self, A_l1, B_l1, m, k, n, N1MN0_to_MN=True, precision_type=FP16):
-        """calculate matrix multiplication A_l1 * B_l1, and move the result to C_ub,
-        then rearrange C_ub
-        :param A_l1: input tensor in L1 with shape of (K1, M, K0)
-        :param B_l1: input tensor in L1 with shape of (K1, N, K0)
-        :param m: the actual number of rows of A_l1
-        :param k: the actual number of cols of A_l1
-        :param n: the actual number of cols of B_l1
-        :param N1MN0_to_MN: Whether reorder the result tensor.
-        :return: C_ub with tensor with shape of (M, N) if N1MN0_to_MN else (N1, M, N0)
-        """
-        M = self.up_align_to_K0(m)
-        N = self.up_align_to_K0(n)
-        C_ub = self.tik_instance.Tensor(precision_type, (N // 16, M, 16), name="C_ub", scope=UB)
-        try:
-            dtype_size = DTYPE_SIZE[FP32]
-        except KeyError:
-            raise ValueError("The argument 'dtype' is not valid.")
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            # matmul
-            C_l0c = self.tik_instance.Tensor(
-                FP32, (N // 16, M, 16), scope=L0C, name="C_l0c"
-            )  # n1mn0 (n0=16)
-            self.tik_instance.matmul(C_l0c, A_l1, B_l1, m, k, n)
-            # L0C -> UB, fp32 -> fp16 (tensor_mov can convert on the fly)
-            self.tik_instance.tensor_mov(C_ub, C_l0c, "m", 1, M * N * dtype_size // 1024, 0, 0)
-        if N1MN0_to_MN:
-            return self.N1MN0_TO_MN(C_ub)
-        return C_ub
-
-    def move_vector_from_gm_to_ub(self, dst_tensor, src_tensor, gm_offset, vec_len):
-        """load the vector from gm to ub
-        :param dst_tensor:
-        :param src_tensor:
-        :param gm_offset:
-        :return:
-        """
-        try:
-            dtype_size = DTYPE_SIZE[src_tensor.dtype]
-        except KeyError:
-            raise ValueError("The argument 'src_tensor dtype' is not valid.")
-        a_burst_num = 32 // dtype_size
-        full_tik_blk_num, tail_num = divmod(vec_len, a_burst_num)
-        with self.tik_instance.if_scope(full_tik_blk_num > 0):
-            self.cont_data_mv_1_bust(dst=dst_tensor, src=src_tensor[gm_offset],
-                                     burst=full_tik_blk_num)
-        # roll the address back to handle the tail data
-        with self.tik_instance.if_scope(tail_num > 0):
-            offset = vec_len - a_burst_num
-            last_blk_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="last_blk_ub", scope=UB)
-            self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
-            with self.tik_instance.for_range(0, a_burst_num) as idx:  # offset is not 32-byte aligned, so data_move cannot be used
-                dst_tensor[offset + idx].set_as(last_blk_ub[idx])
-
-    def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
-        """write the vector back to gm
-        :param dst_tensor:
-        :param src_tensor:
-        :param gm_offset:
-        :param block_h:
-        :return:
-        """
-        try:
-            dtype_size = DTYPE_SIZE[src_tensor.dtype]
-        except KeyError:
-            raise ValueError("The argument 'src_tensor dtype' is not valid.")
-        a_burst_num = 32 // dtype_size
-        full_tik_blk_num = block_h // a_burst_num
-        with self.tik_instance.if_scope(full_tik_blk_num > 0):
-            self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset], src=src_tensor,
-                                     burst=full_tik_blk_num)
-        tail_num = block_h % a_burst_num
-        with self.tik_instance.if_scope(tail_num > 0):
-            offset = block_h - a_burst_num
-            tmp_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="tmp_ub", scope=UB)
-            with self.tik_instance.for_range(0, a_burst_num) as idx:
-                tmp_ub[idx].set_as(src_tensor[offset + idx])
-            self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset + offset], src=tmp_ub, burst=1)
-
-    def scale_compute_vector(self, Sij_ub, dim):
-        """scale compute vector"""
-        scale_value = dim ** -0.5
-        scale = self.tik_instance.Scalar(dtype=FP16)
-        scale.set_as(scale_value)
-        self.tik_instance.h_mul(Sij_ub, Sij_ub, scale)
-        return Sij_ub
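The layout helpers in the deleted TikOpsUtils class mirror plain NumPy reshapes, as their docstrings note. A host-side NumPy sketch of the same transforms (illustrative only; the concrete sizes below are arbitrary) makes the fractal (K1, M, K0) layout easier to picture:

# NumPy reference for the fractal layouts the deleted on-device helpers implemented;
# the expressions are taken directly from the docstrings above.
import numpy as np

M, K, N, K0 = 32, 48, 64, 16
K1 = K // K0

x_mk = np.arange(M * K).reshape(M, K)
x_kn = np.arange(K * N).reshape(K, N)

# MK_TO_K1MK0: (M, K) -> (K1, M, K0)
k1mk0 = np.stack(np.hsplit(x_mk, K1), axis=0)

# KN_TO_K1NK0: (K, N) -> (K1, N, K0)
k1nk0 = np.reshape(x_kn, (K1, K0, N)).swapaxes(1, 2)

# N1MN0_TO_MN: (N1, M, N0) -> (M, N1 * N0), the inverse of MK_TO_K1MK0
mn = np.concatenate([np.squeeze(t, axis=0) for t in np.split(k1mk0, K1)], axis=1)

assert k1mk0.shape == (K1, M, K0)
assert k1nk0.shape == (K1, N, K0)
assert np.array_equal(mn, x_mk)   # the round trip restores the original (M, K) tensor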
File without changes
mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py (deleted)
@@ -1,45 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""tiling for sparse """
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
-
-
-class SparseTiling(TilingStrategy):
-    """A tiling strategy implementation for sparse shape"""
-
-    @classmethod
-    def strategy_name(cls):
-        return "sparse"
-
-    def tiling(self) -> TilingPara:
-        self.Br = min(128, self.Nq)
-        self.Bc = min(128, self.N)
-
-        self.Tr = self.Nq // self.Br
-        self.Tc = self.N // self.Bc
-
-        if self.Nq % self.Br != 0:
-            self.last_Br = self.Nq - self.Tr * self.Br
-            self.Tr += 1
-        else:
-            self.last_Br = self.Br
-        if self.N % self.Bc != 0:
-            self.last_Bc = self.N - self.Tc * self.Bc
-            self.Tc += 1
-        else:
-            self.last_Bc = self.Bc
-
-        return self.gen_tiling_para()
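The tiling rule in the deleted SparseTiling.tiling method is small enough to restate standalone. The sketch below is a plain-Python reproduction of that arithmetic, not a MindSpore API, and the Nq/N values are arbitrary examples showing how the block counts and tail-block sizes fall out:

# Standalone reproduction of the deleted SparseTiling arithmetic.
def sparse_tiling(Nq: int, N: int):
    Br, Bc = min(128, Nq), min(128, N)       # row/column block sizes
    Tr, Tc = Nq // Br, N // Bc               # number of full blocks
    last_Br, last_Bc = Br, Bc
    if Nq % Br:                              # partial trailing row block
        last_Br = Nq - Tr * Br
        Tr += 1
    if N % Bc:                               # partial trailing column block
        last_Bc = N - Tc * Bc
        Tc += 1
    return Br, Bc, Tr, Tc, last_Br, last_Bc

# 300 query rows -> two full 128-row blocks plus a 44-row tail block.
print(sparse_tiling(300, 512))   # (128, 128, 3, 4, 44, 128)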