mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py

@@ -0,0 +1,602 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+import torch
+import torch.nn as nn
+import numpy as np
+
+from einops import rearrange
+
+
+from msprobe.pytorch.common.utils import logger
+
+GTYPE = torch.float64  # ARM hosts must use float64; on x86, float32 suffices (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios.
+SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"
+
+FaForwardParams = namedtuple("FaForwardParams",
+                             ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"])
+FaBackwardParams = namedtuple("FaBackwardParams",
+                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"])
+RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
+                                  ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"])
+
+
+def softmax_forward(x):
+    x_max = torch.max(x, dim=-1, keepdims=True)[0]
+    x_sub = x.sub(x_max)
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=-1, keepdims=True)
+    res = y.div(x_sum)
+    return res, x_max, x_sum
+
+
+def softmax_grad(dp, softmax_res):
+    muls = dp * softmax_res
+    muls_r = muls.sum(dim=-1, keepdims=True)
+    sub_r = dp - muls_r
+    res = sub_r * softmax_res
+    return res
+
+
+def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+    if num_kv_heads == 0 or num_kv_heads > num_heads:
+        raise ValueError("num_kv_heads must be non-zero and must not exceed num_heads.")
+
+    factor = num_heads // num_kv_heads
+    kv_shape = kv_tensor.shape
+    b = kv_shape[0]
+    s = kv_shape[2]
+    d = kv_shape[3]
+    kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
+    for i in range(num_heads):
+        j = i // factor
+        kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+    return kv_res
+
+
+def calculate_qk(q, k, attn_mask, pse, scalar_value):
+    if k.dim() not in (3, 4):
+        raise ValueError(f"k tensor dimension must be 3 or 4, but got {k.dim()} dimensions (shape: {k.shape})")
+
+    if k.dim() == 3:
+        k = k.unsqueeze(1)  # expand along the head dimension
+
+    if pse is None or len(pse.shape) == 0:
+        qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value)
+    else:
+        qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value)
+    if attn_mask is None or len(attn_mask.shape) == 0:
+        return qk
+    else:
+        qk = qk + attn_mask.bool() * (-40000.0)  # large negative bias for masked positions (some implementations use -10000)
+        return qk
+
+
+def fusion_attention_forward(forward_params):
+    q = forward_params.q
+    k = forward_params.k
+    v = forward_params.v
+    drop_mask = forward_params.drop_mask
+    attn_mask = forward_params.attn_mask
+    pse = forward_params.pse
+    scalar_value = forward_params.scalar_value
+    keep_prob = forward_params.keep_prob
+
+    qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+    softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+    if drop_mask is None or len(drop_mask.shape) == 0:
+        drop_res = softmax_res
+    else:
+        drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+    y = torch.matmul(drop_res, v)
+    return y, softmax_max, softmax_sum
+
+
+def fusion_attention_backward(backward_params):
+    dx = backward_params.dx
+    q = backward_params.q
+    k = backward_params.k
+    v = backward_params.v
+    softmax_res = backward_params.softmax_res
+    drop_mask = backward_params.drop_mask
+    pse = backward_params.pse
+    scalar_value = backward_params.scalar_value
+    keep_prob = backward_params.keep_prob
+    dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+    if drop_mask is None or len(drop_mask.shape) == 0:
+        drop_res = softmax_res.permute(0, 1, 3, 2)
+        dp_drop = dp
+    else:
+        drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+        dp_drop = dp * drop_mask * (1.0 / keep_prob)
+    dv = torch.matmul(drop_res, dx)
+    softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value)
+    dq = torch.matmul(softmax_grad_res, k)
+    dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+    return dq, dk, dv
+
+
+def parse_bsnd_args(query, key, head_num, input_layout):
+    supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+    b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
+
+    if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+        raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+    if input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    try:
+        if input_layout == "BSH":
+            b, s1, h1 = query.shape
+            _, s2, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
+        elif input_layout == "SBH":
+            s1, b, h1 = query.shape
+            s2, _, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
+        elif input_layout == "BSND":
+            b, s1, n1, d = query.shape
+            _, s2, n2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
+        elif input_layout == "BNSD":
+            b, n1, s1, d = query.shape
+            _, n2, s2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
+    except Exception as e:
+        raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+    if d == 0:
+        raise ValueError("Value d must be non-zero.")
+    _dtype = query.dtype
+    ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+    return ret
+
+
+def convert_from_bnsd(_input, input_layout):
+    """
+    transform qkv from bnsd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,N,S,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,N,S,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+    elif input_layout == "BSND":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bnsd(_input, n, input_layout):
+    """
+    transform qkv from input_layout to bnsd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,N,S,D)
+        out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+    elif input_layout == "BSND":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out.to(GTYPE)
+
+
+def convert_from_bsnd(_input, input_layout):
+    """
+    transform qkv from bsnd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,S,N,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,S,N,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,S,N,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+    elif input_layout == "BNSD":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bsnd(_input, n, input_layout):
+    """
+    transform qkv from input_layout to bsnd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,S,N,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+    elif input_layout == "BNSD":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out
+
+
+def generate_attn_mask(*args):
+    """
+    When sparse_mode is 2, 3 or 4, the fused operator optimizes the mask handling of the
+    small-operator implementation, so rebuilding the reference requires unpacking back to the basic form:
+    ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+    """
+
+    sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+    shape = [s1, s2]
+
+    if attn_mask is not None:
+        # If the FA input already contains attn_mask, it is already the converted mask matrix; there are three special (sparse) scenarios that must be reverse-restored.
+        if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+            logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}")
+
+            if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048:
+                if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)):
+                    if sparse_mode == 2:
+                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                    elif sparse_mode == 3:
+                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+                    elif sparse_mode == 4:
+                        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                        attn_mask = attn_mask_u + attn_mask_l
+                    logger.debug(f"Reverse-converted attn_mask {attn_mask.shape}")
+                    return attn_mask.to(dtype)
+
+        return attn_mask.to(dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dim() == 2:
+            if attn_mask.shape[0] != s1 or attn_mask.shape[1] != s2:
+                raise ValueError(f"Invalid attn_mask shape `SS` {attn_mask.shape}")
+            shape = [s1, s2]
+        elif attn_mask.dim() == 4:
+            if attn_mask.shape[1] == 1:
+                shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
+            else:
+                shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]
+
+    if sparse_mode == 0:
+        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+        attn_mask = attn_mask_u + attn_mask_l
+    elif sparse_mode == 1:  # no sparse
+        attn_mask = torch.from_numpy(np.zeros(shape))
+    elif sparse_mode == 2:
+        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+    elif sparse_mode == 3:
+        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+    elif sparse_mode == 4:
+        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+        attn_mask = attn_mask_u + attn_mask_l
+    # Note: sparse_mode=5 cannot occur here; that mode requires attn_mask to be passed in with
+    # format BNSS or B1SS, so the FA input can be assumed to already be the correct attn_mask.
+    return attn_mask.to(dtype)
+
+
+def generate_kv(key, value, n1, n2):
+    # adaptation for unequal kv head counts (n1 != n2), by cdy
+    if not (n1 == n2):
+        k_new = broadcast_kv(n1, n2, key, key.dtype)
+        v_new = broadcast_kv(n1, n2, value, value.dtype)
+    else:
+        k_new = key
+        v_new = value
+    return k_new, v_new
+
+
+def rebuild_softmax_by_qkv(q, k, attn_mask, pse, scalar_value):
+    """
+    attention = softmax(QK^T/sqrt(d))V
+    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+    """
+    logger.info("Using QKV to rebuild original softmax")
+    qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+    softmax_res, _, _ = softmax_forward(qk)
+    return softmax_res
+
+
+def rebuild_softmax_by_max_sum(softmax_params):
+    """
+    attention = softmax(QK^T/sqrt(d))V
+    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+    """
+    q = softmax_params.q
+    k = softmax_params.k
+    attn_mask = softmax_params.attn_mask
+    pse = softmax_params.pse
+    scalar_value = softmax_params.scalar_value
+    softmax_max = softmax_params.softmax_max
+    softmax_sum = softmax_params.softmax_sum
+    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+
+    qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+    if softmax_max.shape[-1] == 0:
+        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+    repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+    softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+        softmax_sum.repeat(1, 1, 1, repeat_dim))
+    return softmax_res
+
+
+def get_head_num(*args, **kwargs):
+    if kwargs.get("head_num", None):
+        head_num = kwargs.get("head_num")
+    elif len(args) >= 4:
+        head_num = args[3]
+    else:
+        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+    return head_num
+
+
+def get_input_layout(*args, **kwargs):
+    if kwargs.get("input_layout", None):
+        input_layout = kwargs.get("input_layout")
+    elif len(args) >= 5:
+        input_layout = args[4]
+    else:
+        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+    return input_layout
+
+
+def npu_fusion_attention_forward_patch(*args, **kwargs):
+    if len(args) < 2:
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")
+
+    # query, key, value, head_num, input_layout
+    head_num = get_head_num(*args, **kwargs)
+    input_layout = get_input_layout(*args, **kwargs)
+
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    if n1 == n2 and s1 == s2:
+        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    else:
+        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+    new_kwargs = {
+        "keep_prob": 1,
+        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "attn_mask": kwargs.get("attn_mask")
+    }
+
+    return args, dims_kwargs, new_kwargs
+
+
+def npu_fusion_attention_backward_patch(*args, **kwargs):
+    if len(args) != 6:
+        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    if n1 == n2 and s1 == s2:
+        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    else:
+        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "softmax_max": kwargs.get("softmax_max"),
+        "softmax_sum": kwargs.get("softmax_sum"),
+        "softmax_in": kwargs.get("softmax_in"),
+        "attention_in": kwargs.get("attention_in"),
+        "seed": kwargs.get("seed", 0),
+        "offset": kwargs.get("offset", 0),
+        "numels": kwargs.get("numels", 0),
+        "attn_mask": kwargs.get("attn_mask")
+    }
+
+    return args, dims_kwargs, new_kwargs
+
+
+class FlashAttentionScore(nn.Module):
+    def __init__(self):
+        super(FlashAttentionScore, self).__init__()
+        # You can initialize any parameters here if necessary
+
+    def forward(self, *inputs, **kwargs):
+        # Extract the inputs for the attention calculation
+        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs)
+        query, key, value = new_args[0], new_args[1], new_args[2]
+
+        input_layout = get_input_layout(*inputs, **kwargs)
+
+        n1 = dims_kwargs.get("n1")
+        n2 = dims_kwargs.get("n2")
+        s1 = dims_kwargs.get("s1")
+        s2 = dims_kwargs.get("s2")
+        b = dims_kwargs.get("b")
+        dtype = dims_kwargs.get("dtype")
+        attn_mask = new_kwargs.get("attn_mask")
+        keep_prob = new_kwargs.get("keep_prob")
+        sparse_mode = new_kwargs.get("sparse_mode")
+        pre_tockens = new_kwargs.get("pre_tockens")
+        next_tockens = new_kwargs.get("next_tockens")
+        pse = new_kwargs.get("pse")
+        scalar_value = new_kwargs.get("scalar_value")
+
+        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+
+        attn_mask = generate_attn_mask(*args_temp)
+        query = convert_to_bnsd(query, n1, input_layout)
+        key = convert_to_bnsd(key, n2, input_layout)
+        value = convert_to_bnsd(value, n2, input_layout)
+
+        forward_params = FaForwardParams(
+            q=query,
+            k=key,
+            v=value,
+            drop_mask=None,
+            attn_mask=attn_mask,
+            pse=pse,
+            scalar_value=scalar_value,
+            keep_prob=keep_prob
+        )
+
+        out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
+
+        # If output dimension is 5, reshape accordingly
+        if out_golden.dim() == 5:
+            out_golden = out_golden.reshape(out_golden.size(0),
+                                            out_golden.size(1) * out_golden.size(2),
+                                            out_golden.size(3), out_golden.size(4))
+
+        out_golden = convert_from_bnsd(out_golden, input_layout)
+
+        # Ensure the output matches the desired layout
+        out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+        return out_golden
+
+    def backward(self, *inputs, **kwargs):
+        # The backward pass mirrors the gradient computation described above
+        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs)
+        query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+        n1 = dims_kwargs.get("n1")
+        n2 = dims_kwargs.get("n2")
+        s1 = dims_kwargs.get("s1")
+        s2 = dims_kwargs.get("s2")
+        b = dims_kwargs.get("b")
+        dtype = dims_kwargs.get("dtype")
+        attn_mask = new_kwargs.get("attn_mask")
+        keep_prob = new_kwargs.get("keep_prob")
+        sparse_mode = new_kwargs.get("sparse_mode")
+        pre_tockens = new_kwargs.get("pre_tockens")
+        next_tockens = new_kwargs.get("next_tockens")
+        pse = new_kwargs.get("pse")
+        softmax_max = new_kwargs.get("softmax_max")
+        softmax_sum = new_kwargs.get("softmax_sum")
+        scalar_value = new_kwargs.get("scalar_value")
+
+        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+        attn_mask = generate_attn_mask(*args_temp)
+
+        query = convert_to_bnsd(query, n1, input_layout)
+        dx = convert_to_bnsd(dx, n1, input_layout)
+        key = convert_to_bnsd(key, n2, input_layout)
+        value = convert_to_bnsd(value, n2, input_layout)
+
+        k_new, v_new = generate_kv(key, value, n1, n2)
+
+        if SOFTMAX_BUILD_MODE == "QKV":
+            softmax_res = rebuild_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value)
+        else:
+            softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum)
+            softmax_res = rebuild_softmax_by_max_sum(softmax_params)
+
+        backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob)
+        dq, dk, dv = fusion_attention_backward(backward_params)
+
+        # Reshape as needed
+        if dq.dim() == 5:
+            dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+        if dk.dim() == 5:
+            dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+        if dv.dim() == 5:
+            dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+        dq = convert_from_bnsd(dq, input_layout)
+        dk = convert_from_bnsd(dk, input_layout)
+        dv = convert_from_bnsd(dv, input_layout)
+
+        return dq.cpu(), dk.cpu(), dv.cpu()
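
The file above recomputes npu_fusion_attention on the host with small operators, so its math can be sanity-checked against autograd. Below is a minimal float64 sketch (an editorial addition, not part of the package diff) that compares the golden forward/backward against torch.autograd on a tiny BNSD case; all imported names come from the new module above:

import torch

from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import (
    FaBackwardParams, FaForwardParams, fusion_attention_backward,
    fusion_attention_forward, rebuild_softmax_by_qkv,
)

torch.manual_seed(0)
b, n, s, d = 1, 2, 4, 8
q = torch.randn(b, n, s, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(b, n, s, d, dtype=torch.float64, requires_grad=True)
v = torch.randn(b, n, s, d, dtype=torch.float64, requires_grad=True)
scale = 1 / d ** 0.5

# Golden forward in BNSD layout, with no dropout, mask, or pse.
out, _, _ = fusion_attention_forward(FaForwardParams(
    q=q, k=k, v=v, drop_mask=None, attn_mask=None,
    pse=None, scalar_value=scale, keep_prob=1.0))

# Plain-PyTorch reference for the same computation.
ref = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) * scale, dim=-1) @ v
assert torch.allclose(out, ref)

# Golden backward: rebuild the softmax from Q/K, then propagate an upstream grad dx.
dx = torch.randn_like(out)
softmax_res = rebuild_softmax_by_qkv(q, k, None, None, scale)
dq, dk, dv = fusion_attention_backward(FaBackwardParams(
    dx=dx, q=q, k=k, v=v, softmax_res=softmax_res,
    drop_mask=None, pse=None, scalar_value=scale, keep_prob=1.0))

ref.backward(dx)
assert torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad)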
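The trickiest path in generate_attn_mask is the sparse-mask restoration: for sparse_mode 2/3/4 the fused operator receives a fixed 2048x2048 upper-triangular mask, which must be expanded back to the true (s1, s2) mask before the small-operator reference can use it. A short sketch exercising that branch (again an editorial addition, assuming the module path added in this release):

import numpy as np
import torch

from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import generate_attn_mask

s1 = s2 = 16
# The compressed mask the fused operator receives for sparse_mode in (2, 3, 4).
compressed = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(torch.float32)
# Positional args: sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype.
full = generate_attn_mask(3, compressed, 1, 2, s1, s2, 65536, 0, torch.float32)
assert full.shape == (s1, s2)
# sparse_mode=3 rebuilds a causal mask with diagonal offset s2 - s1 + 1.
assert full.equal(torch.from_numpy(np.triu(np.ones([s1, s2]), k=s2 - s1 + 1)).to(torch.float32))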
msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py

@@ -0,0 +1,41 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore
+
+
+class FusionOperator:
+    """
+    Parent class for all fusion operators; defines the common interface and attributes.
+    """
+
+    # Initialize the operator registry
+    def __init__(self):
+        self.flash_attention_score = None  # holds the FlashAttentionScore operator
+        self._register_operators()
+
+    def __getattr__(self, name):
+        """Report unknown operators (__getattr__ only runs after normal lookup fails)."""
+        # A hasattr(self, name) check here would re-enter __getattr__ and recurse
+        # endlessly, because __getattr__ is only invoked for missing attributes,
+        # so the unknown name is reported directly.
+        raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")
+
+    def _register_operators(self):
+        """Register operators on this class so they can be called as fusion.xxx."""
+        self.flash_attention_score = FlashAttentionScore()
+
+
+fusion = FusionOperator()
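
The registry simply exposes CPU golden implementations as attributes of a singleton. A hypothetical usage sketch (not part of the diff; shapes and arguments are illustrative) in which the positional arguments mirror npu_fusion_attention's query, key, value, head_num, input_layout:

import torch

from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Resolve and run the golden FlashAttentionScore through the registry.
out, softmax_max, softmax_sum = fusion.flash_attention_score(q, k, v, 2, "BNSD")
print(out.shape, out.dtype)  # torch.Size([1, 2, 4, 8]) torch.float64 (computed in GTYPE)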
msprobe/mindspore/api_accuracy_checker/data_manager.py

@@ -16,12 +16,13 @@
 import os
 import csv
 
-from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.const import Const, CompareConst
 from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
 from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
 from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
 from msprobe.core.common.file_utils import check_file_or_directory_path
 from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class ResultCsvEntry:

msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py

@@ -27,10 +27,11 @@ import numpy as np
 from tqdm import tqdm
 
 # Local application/library-specific imports
-from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.const import Const, CompareConst
 from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
 from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
 from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class MultiApiAccuracyChecker(ApiAccuracyChecker):

msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py

@@ -19,7 +19,8 @@ import sys
 from pathlib import Path
 import mindspore
 from msprobe.mindspore.common.log import logger
-from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.const import Const, CompareConst
+from msprobe.mindspore.common.const import MsCompareConst
 import torch as mindtorch
 from torch import Tensor as mindtorch_tensor
 import torch.nn.functional as mindtorch_func