mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,602 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import namedtuple
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ from einops import rearrange
+
+
+ from msprobe.pytorch.common.utils import logger
+
+ GTYPE = torch.float64  # float64 is required on ARM hosts; on x86, float32 is enough (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios.
+ SOFTMAX_BUILD_MODE = "QKV"  # alternative: "MAX_SUM"
+
+ FaForwardParams = namedtuple("FaForwardParams",
+                              ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"])
+ FaBackwardParams = namedtuple("FaBackwardParams",
+                               ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"])
+ RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
+                                   ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"])
+
+
+ def softmax_forward(x):
+     x_max = torch.max(x, dim=-1, keepdims=True)[0]
+     x_sub = x.sub(x_max)
+     y = torch.exp(x_sub)
+     x_sum = y.sum(dim=-1, keepdims=True)
+     res = y.div(x_sum)
+     return res, x_max, x_sum
+
+
+ def softmax_grad(dp, softmax_res):
+     muls = dp * softmax_res
+     muls_r = muls.sum(dim=-1, keepdims=True)
+     sub_r = dp - muls_r
+     res = sub_r * softmax_res
+     return res
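+ # Note on softmax_forward: subtracting the row max before exponentiating is the standard
+ # overflow guard; e.g. softmax_forward(torch.tensor([[1., 2., 3.]])) yields
+ # res ≈ [[0.0900, 0.2447, 0.6652]] with x_max = [[3.]] and x_sum ≈ [[1.5032]].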
+
+
+ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+     if num_kv_heads == 0 or num_kv_heads > num_heads:
+         raise ValueError("num_kv_heads must be non-zero and must not exceed num_heads.")
+
+     factor = num_heads // num_kv_heads
+     kv_shape = kv_tensor.shape
+     b = kv_shape[0]
+     s = kv_shape[2]
+     d = kv_shape[3]
+     kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
+     for i in range(num_heads):
+         j = i // factor
+         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+     return kv_res
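+ # Note on broadcast_kv: with num_heads=8 and num_kv_heads=2, factor is 4, so each of the two
+ # kv heads is replicated across four query heads (the grouped-query attention broadcast).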
+
+
+ def calculate_qk(q, k, attn_mask, pse, scalar_value):
+     if k.dim() == 3:
+         k = k.unsqueeze(1)  # expand along the head dimension
+
+     if k.dim() != 4:
+         raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")
+
+     if pse is None or len(pse.shape) == 0:
+         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value)
+     else:
+         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value)
+     if attn_mask is None or len(attn_mask.shape) == 0:
+         return qk
+     else:
+         qk = qk + attn_mask.bool() * (-40000.0)  # -10000
+         return qk
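+ # Note on calculate_qk: masked positions receive a large negative bias (-40000.0; the inline
+ # "-10000" marks a smaller constant sometimes used instead) so softmax drives them toward zero.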
+
+
+ def fusion_attention_forward(forward_params):
+     q = forward_params.q
+     k = forward_params.k
+     v = forward_params.v
+     drop_mask = forward_params.drop_mask
+     attn_mask = forward_params.attn_mask
+     pse = forward_params.pse
+     scalar_value = forward_params.scalar_value
+     keep_prob = forward_params.keep_prob
+
+     qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res
+     else:
+         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+     y = torch.matmul(drop_res, v)
+     return y, softmax_max, softmax_sum
+
+
+ def fusion_attention_backward(backward_params):
+     dx = backward_params.dx
+     q = backward_params.q
+     k = backward_params.k
+     v = backward_params.v
+     softmax_res = backward_params.softmax_res
+     drop_mask = backward_params.drop_mask
+     pse = backward_params.pse
+     scalar_value = backward_params.scalar_value
+     keep_prob = backward_params.keep_prob
+     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res.permute(0, 1, 3, 2)
+         dp_drop = dp
+     else:
+         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+         dp_drop = dp * drop_mask * (1.0 / keep_prob)
+     dv = torch.matmul(drop_res, dx)
+     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value)
+     dq = torch.matmul(softmax_grad_res, k)
+     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+     return dq, dk, dv
+
+
+ def parse_bsnd_args(query, key, head_num, input_layout):
+     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+     b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
+
+     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+     if input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     try:
+         if input_layout == "BSH":
+             b, s1, h1 = query.shape
+             _, s2, h2 = key.shape
+             d = h1 // n1
+             n2 = h2 // d
+         elif input_layout == "SBH":
+             s1, b, h1 = query.shape
+             s2, _, h2 = key.shape
+             d = h1 // n1
+             n2 = h2 // d
+         elif input_layout == "BSND":
+             b, s1, n1, d = query.shape
+             _, s2, n2, _ = key.shape
+             h1 = n1 * d
+             h2 = n2 * d
+         elif input_layout == "BNSD":
+             b, n1, s1, d = query.shape
+             _, n2, s2, _ = key.shape
+             h1 = n1 * d
+             h2 = n2 * d
+     except Exception as e:
+         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+     if d == 0:
+         raise ValueError("Value d must be non-zero.")
+     _dtype = query.dtype
+     ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+     return ret
+
+
+ def convert_from_bnsd(_input, input_layout):
+     """
+     transform qkv from bnsd to input_layout.
+     B: batch_size
+     S: sequence_length
+     N: num_heads
+     D: head_dim
+     Args:
+         _input (torch.Tensor): tensor of shape (B,N,S,D)
+         input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+     Returns:
+         tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+     """
+     if input_layout == "BSH":
+         # (B,N,S,D)=>(B,S,N*D)
+         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+     elif input_layout == "SBH":
+         # (B,N,S,D)=>(S,B,N*D)
+         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+     elif input_layout == "BSND":
+         # (B,N,S,D)=>(B,S,N,D)
+         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     return out
+
+
+ def convert_to_bnsd(_input, n, input_layout):
+     """
+     transform qkv from input_layout to bnsd.
+     B: batch_size
+     S: sequence_length
+     N: num_heads
+     D: head_dim
+     Args:
+         _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+         n (int): num_heads
+         input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+     Returns:
+         tensor of shape (B,N,S,D)
+     """
+     if input_layout == "BSH":
+         # (B,S,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+     elif input_layout == "SBH":
+         # (S,B,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+     elif input_layout == "BSND":
+         # (B,S,N,D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s n d -> b n s d', n=n)
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     if out.dim() != 4:
+         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+     return out.to(GTYPE)
+
+
+ def convert_from_bsnd(_input, input_layout):
+     """
+     transform qkv from bsnd to input_layout.
+     B: batch_size
+     S: sequence_length
+     N: num_heads
+     D: head_dim
+     Args:
+         _input (torch.Tensor): tensor of shape (B,S,N,D)
+         input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+     Returns:
+         tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+     """
+     if input_layout == "BSH":
+         # (B,S,N,D)=>(B,S,N*D)
+         out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+     elif input_layout == "SBH":
+         # (B,S,N,D)=>(S,B,N*D)
+         out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+     elif input_layout == "BNSD":
+         # (B,S,N,D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     return out
+
+
+ def convert_to_bsnd(_input, n, input_layout):
+     """
+     transform qkv from input_layout to bsnd.
+     B: batch_size
+     S: sequence_length
+     N: num_heads
+     D: head_dim
+     Args:
+         _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+         n (int): num_heads
+         input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+     Returns:
+         tensor of shape (B,S,N,D)
+     """
+     if input_layout == "BSH":
+         # (B,S,N*D)=>(B,S,N,D)
+         out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+     elif input_layout == "SBH":
+         # (S,B,N*D)=>(B,S,N,D)
+         out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+     elif input_layout == "BNSD":
+         # (B,N,S,D)=>(B,S,N,D)
+         out = rearrange(_input, 'b n s d -> b s n d', n=n)
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     if out.dim() != 4:
+         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+     return out
+
+
+ def generate_attn_mask(*args):
+     """
+     When sparse_mode is 2, 3, or 4, the small-op-to-fused-op conversion applies this optimization,
+     so reversing it means decomposing the mask back to the original basic implementation:
+     ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+     """
+
+     sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+     shape = [s1, s2]
+
+     if attn_mask is not None:
+         # If the FA input already contains attn_mask, it can be assumed to be the converted mask matrix.
+         # Three special (sparse-matrix) scenarios require inverse restoration.
+         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+             logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}")
+
+             if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048:
+                 if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)):
+                     if sparse_mode == 2:
+                         attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                     elif sparse_mode == 3:
+                         attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+                     elif sparse_mode == 4:
+                         attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                         attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                         attn_mask = attn_mask_u + attn_mask_l
+                     logger.debug(f"inverse-converted attn_mask {attn_mask.shape}")
+                     return attn_mask.to(dtype)
+
+         return attn_mask.to(dtype)
+
+     if sparse_mode == 0:
+         attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         attn_mask = attn_mask_u + attn_mask_l
+     elif sparse_mode == 1:  # no sparse
+         attn_mask = torch.from_numpy(np.zeros(shape))
+     elif sparse_mode == 2:
+         attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+     elif sparse_mode == 3:
+         attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+     elif sparse_mode == 4:
+         attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         attn_mask = attn_mask_u + attn_mask_l
+     # Note: sparse_mode=5 never reaches this point, because it requires attn_mask to be passed in
+     # (in BNSS or B1SS format), so the FA input can be assumed to already be the correct attn_mask.
+     return attn_mask.to(dtype)
+
+
+ def generate_kv(key, value, n1, n2):
+     # Adaptation for mismatched kv head counts (by cdy)
+     if not (n1 == n2):
+         k_new = broadcast_kv(n1, n2, key, key.dtype)
+         v_new = broadcast_kv(n1, n2, value, value.dtype)
+     else:
+         k_new = key
+         v_new = value
+     return k_new, v_new
+
+
+ def rebuild_softmax_by_qkv(q, k, attn_mask, pse, scalar_value):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+     """
+     logger.info("Using QKV to rebuild original softmax")
+     qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+     softmax_res, _, _ = softmax_forward(qk)
+     return softmax_res
+
+
+ def rebuild_softmax_by_max_sum(softmax_params):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+     """
+     q = softmax_params.q
+     k = softmax_params.k
+     attn_mask = softmax_params.attn_mask
+     pse = softmax_params.pse
+     scalar_value = softmax_params.scalar_value
+     softmax_max = softmax_params.softmax_max
+     softmax_sum = softmax_params.softmax_sum
+     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+
+     qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+     if softmax_max.shape[-1] == 0:
+         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+         softmax_sum.repeat(1, 1, 1, repeat_dim))
+     return softmax_res
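+ # Note on rebuild_softmax_by_max_sum: softmax_max and softmax_sum arrive with a trailing
+ # dimension narrower than qk's last axis (the fused kernel emits a fixed-width statistic),
+ # so both are tiled repeat_dim times along the last axis before the exp/div reconstruction.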
+
+
+ def get_head_num(*args, **kwargs):
+     if kwargs.get("head_num", None):
+         head_num = kwargs.get("head_num")
+     elif len(args) >= 4:
+         head_num = args[3]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return head_num
+
+
+ def get_input_layout(*args, **kwargs):
+     if kwargs.get("input_layout", None):
+         input_layout = kwargs.get("input_layout")
+     elif len(args) >= 5:
+         input_layout = args[4]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return input_layout
+
+
+ def npu_fusion_attention_forward_patch(*args, **kwargs):
+     if len(args) < 2:
+         raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")
+
+     # query, key, value, head_num, input_layout
+     head_num = get_head_num(*args, **kwargs)
+     input_layout = get_input_layout(*args, **kwargs)
+
+     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+     if n1 == n2 and s1 == s2:
+         logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (n1 % n2 == 0 and n1 >= n2):
+         raise ValueError(f"n1 and n2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+     dims_kwargs = {
+         "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+         "d": d, "h1": h1, "h2": h2, "dtype": dtype
+     }
+     new_kwargs = {
+         "keep_prob": 1,
+         "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+         "sparse_mode": kwargs.get("sparse_mode", 0),
+         "prefix": kwargs.get("prefix"),
+         "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+         "next_tockens": kwargs.get("next_tockens", 2147483647),
+         "pse": kwargs.get("pse"),
+         "padding_mask": kwargs.get("padding_mask"),
+         "attn_mask": kwargs.get("attn_mask")
+     }
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention_backward_patch(*args, **kwargs):
+     if len(args) != 6:
+         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+     if n1 == n2 and s1 == s2:
+         logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (n1 % n2 == 0 and n1 >= n2):
+         raise ValueError(f"n1 and n2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+     dims_kwargs = {
+         "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+         "d": d, "h1": h1, "h2": h2, "dtype": dtype
+     }
+
+     new_kwargs = {
+         "keep_prob": 1,
+         "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+         "sparse_mode": kwargs.get("sparse_mode", 0),
+         "prefix": kwargs.get("prefix"),
+         "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+         "next_tockens": kwargs.get("next_tockens", 2147483647),
+         "pse": kwargs.get("pse"),
+         "padding_mask": kwargs.get("padding_mask"),
+         "softmax_max": kwargs.get("softmax_max"),
+         "softmax_sum": kwargs.get("softmax_sum"),
+         "softmax_in": kwargs.get("softmax_in"),
+         "attention_in": kwargs.get("attention_in"),
+         "seed": kwargs.get("seed", 0),
+         "offset": kwargs.get("offset", 0),
+         "numels": kwargs.get("numels", 0),
+         "attn_mask": kwargs.get("attn_mask")
+     }
+
+     return args, dims_kwargs, new_kwargs
+
+
+ class FlashAttentionScore(nn.Module):
+     def __init__(self):
+         super(FlashAttentionScore, self).__init__()
+         # You can initialize any parameters here if necessary
+
+     def forward(self, *inputs, **kwargs):
+         # Extract the inputs for the attention calculation
+         new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs)
+         query, key, value = new_args[0], new_args[1], new_args[2]
+
+         input_layout = get_input_layout(*inputs, **kwargs)
+
+         n1 = dims_kwargs.get("n1")
+         n2 = dims_kwargs.get("n2")
+         s1 = dims_kwargs.get("s1")
+         s2 = dims_kwargs.get("s2")
+         b = dims_kwargs.get("b")
+         dtype = dims_kwargs.get("dtype")
+         attn_mask = new_kwargs.get("attn_mask")
+         keep_prob = new_kwargs.get("keep_prob")
+         sparse_mode = new_kwargs.get("sparse_mode")
+         pre_tockens = new_kwargs.get("pre_tockens")
+         next_tockens = new_kwargs.get("next_tockens")
+         pse = new_kwargs.get("pse")
+         scalar_value = new_kwargs.get("scalar_value")
+
+         args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+
+         attn_mask = generate_attn_mask(*args_temp)
+         query = convert_to_bnsd(query, n1, input_layout)
+         key = convert_to_bnsd(key, n2, input_layout)
+         value = convert_to_bnsd(value, n2, input_layout)
+
+         forward_params = FaForwardParams(
+             q=query,
+             k=key,
+             v=value,
+             drop_mask=None,
+             attn_mask=attn_mask,
+             pse=pse,
+             scalar_value=scalar_value,
+             keep_prob=keep_prob
+         )
+
+         out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
+
+         # If output dimension is 5, reshape accordingly
+         if out_golden.dim() == 5:
+             out_golden = out_golden.reshape(out_golden.size(0),
+                                             out_golden.size(1) * out_golden.size(2),
+                                             out_golden.size(3), out_golden.size(4))
+
+         out_golden = convert_from_bnsd(out_golden, input_layout)
+
+         # Ensure the output matches the desired layout
+         out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+         return out_golden
+
+     def backward(self, *inputs, **kwargs):
+         # The backward pass mirrors the gradient computation described above
+         new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs)
+         query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+         n1 = dims_kwargs.get("n1")
+         n2 = dims_kwargs.get("n2")
+         s1 = dims_kwargs.get("s1")
+         s2 = dims_kwargs.get("s2")
+         b = dims_kwargs.get("b")
+         dtype = dims_kwargs.get("dtype")
+         attn_mask = new_kwargs.get("attn_mask")
+         keep_prob = new_kwargs.get("keep_prob")
+         sparse_mode = new_kwargs.get("sparse_mode")
+         pre_tockens = new_kwargs.get("pre_tockens")
+         next_tockens = new_kwargs.get("next_tockens")
+         pse = new_kwargs.get("pse")
+         softmax_max = new_kwargs.get("softmax_max")
+         softmax_sum = new_kwargs.get("softmax_sum")
+         scalar_value = new_kwargs.get("scalar_value")
+
+         args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+         attn_mask = generate_attn_mask(*args_temp)
+
+         query = convert_to_bnsd(query, n1, input_layout)
+         dx = convert_to_bnsd(dx, n1, input_layout)
+         key = convert_to_bnsd(key, n2, input_layout)
+         value = convert_to_bnsd(value, n2, input_layout)
+
+         k_new, v_new = generate_kv(key, value, n1, n2)
+
+         if SOFTMAX_BUILD_MODE == "QKV":
+             softmax_res = rebuild_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value)
+         else:
+             softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum)
+             softmax_res = rebuild_softmax_by_max_sum(softmax_params)
+
+         backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob)
+         dq, dk, dv = fusion_attention_backward(backward_params)
+
+         # Reshape as needed
+         if dq.dim() == 5:
+             dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+         if dk.dim() == 5:
+             dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+         if dv.dim() == 5:
+             dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+         dq = convert_from_bnsd(dq, input_layout)
+         dk = convert_from_bnsd(dk, input_layout)
+         dv = convert_from_bnsd(dv, input_layout)
+
+         return dq.cpu(), dk.cpu(), dv.cpu()
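
For orientation, a minimal sketch of driving this CPU golden reference; the shapes, head count, and layout below are illustrative assumptions, not values taken from the diff:

    import torch
    from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore

    # Hypothetical BSND inputs: batch=2, seq=128, heads=8, head_dim=64
    q = torch.randn(2, 128, 8, 64)
    k = torch.randn(2, 128, 8, 64)
    v = torch.randn(2, 128, 8, 64)

    fa = FlashAttentionScore()
    # Positional arguments mirror npu_fusion_attention: query, key, value, head_num, input_layout
    out, softmax_max, softmax_sum = fa(q, k, v, 8, "BSND", sparse_mode=0)

Internally the module converts the inputs to BNSD in GTYPE (float64), rebuilds the attention mask from sparse_mode, recomputes softmax(QK^T * scalar_value) V, and returns CPU tensors for comparison against the fused NPU output.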
@@ -0,0 +1,41 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore
+
+
+ class FusionOperator:
+     """
+     Parent class for all fusion operators; defines the common interface and attributes.
+     """
+
+     # Initialize the operator registry
+     def __init__(self):
+         self.flash_attention_score = None  # holds the FlashAttentionScore operator
+         self._register_operators()
+
+     def __getattr__(self, name):
+         """ Called only when normal attribute lookup fails, so raise directly to avoid recursion """
+         raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")
+
+     def _register_operators(self):
+         """ Register operators so they can be called as fusion.xxx """
+         self.flash_attention_score = FlashAttentionScore()
+
+
+ fusion = FusionOperator()
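
Assuming the q, k, v tensors from the sketch above, the registry acts as a thin dispatch layer; fusion.flash_attention_score(...) is equivalent to calling the FlashAttentionScore module directly:

    from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion

    # Dispatches to the registered FlashAttentionScore instance
    out, softmax_max, softmax_sum = fusion.flash_attention_score(q, k, v, 8, "BSND")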
@@ -16,12 +16,13 @@
  import os
  import csv

- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+ from msprobe.core.common.const import Const, CompareConst
  from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
  from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
  from msprobe.core.common.file_utils import check_file_or_directory_path
  from msprobe.mindspore.common.log import logger
+ from msprobe.mindspore.common.const import MsCompareConst


  class ResultCsvEntry:
@@ -27,10 +27,11 @@ import numpy as np
  from tqdm import tqdm

  # Local application/library-specific imports
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+ from msprobe.core.common.const import Const, CompareConst
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
  from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
  from msprobe.mindspore.common.log import logger
+ from msprobe.mindspore.common.const import MsCompareConst


  class MultiApiAccuracyChecker(ApiAccuracyChecker):
@@ -19,7 +19,8 @@ import sys
  from pathlib import Path
  import mindspore
  from msprobe.mindspore.common.log import logger
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+ from msprobe.core.common.const import Const, CompareConst
+ from msprobe.mindspore.common.const import MsCompareConst
  import torch as mindtorch
  from torch import Tensor as mindtorch_tensor
  import torch.nn.functional as mindtorch_func
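
The three hunks above share one refactor: MsCompareConst moved out of msprobe.core.common.const into the framework-specific msprobe.mindspore.common.const. Downstream code migrates the same way:

    # 1.2.2
    # from msprobe.core.common.const import Const, CompareConst, MsCompareConst

    # 1.3.0
    from msprobe.core.common.const import Const, CompareConst
    from msprobe.mindspore.common.const import MsCompareConst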