mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,580 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+ import torch
18
+ import torch.nn as nn
19
+ import numpy as np
20
+
21
+ from einops import rearrange
22
+
23
+
24
+ from msprobe.pytorch.common.utils import logger
25
+
26
# Reference ("golden") dtype for the CPU recomputation: on ARM hosts float64
# is required; in an x86 environment float32 suffices (float64 also works).
# ARM computation is slow, so x86 is recommended for s=8k scenarios.
GTYPE = torch.float64
# How the original softmax is rebuilt for comparison: from Q/K ("QKV") or
# from the saved softmax_max/softmax_sum statistics ("MAX_SUM").
SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"

# Inputs of the CPU reference forward pass of the fused attention operator.
FaForwardParams = namedtuple("FaForwardParams",
                             ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"])
# Inputs of the CPU reference backward pass (dx is the upstream gradient).
FaBackwardParams = namedtuple("FaBackwardParams",
                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"])
# Inputs needed to rebuild the softmax result from saved max/sum statistics.
RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
                                  ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"])
35
+
36
+
37
def softmax_forward(x):
    """Numerically stable softmax over the last dimension.

    Returns (softmax(x), row-wise max, row-wise sum of shifted exponentials)
    so callers can reuse the intermediate statistics.
    """
    row_max = x.max(dim=-1, keepdim=True).values
    exp_shifted = torch.exp(x - row_max)
    row_sum = exp_shifted.sum(dim=-1, keepdim=True)
    return exp_shifted / row_sum, row_max, row_sum
44
+
45
+
46
def softmax_grad(dp, softmax_res):
    """Backward of softmax: upstream grad ``dp`` and forward output ``softmax_res``."""
    weighted_sum = (dp * softmax_res).sum(dim=-1, keepdims=True)
    return (dp - weighted_sum) * softmax_res
52
+
53
+
54
def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
    """Expand grouped KV heads (B, N_kv, S, D) to (B, num_heads, S, D).

    Each KV head serves ``num_heads // num_kv_heads`` query heads, so output
    head ``i`` is a copy of KV head ``i // factor``.
    """
    # Validate rank and head-count relation before expanding.
    if kv_tensor.dim() != 4:
        raise ValueError(f"broadcast_kv: kv_tensor 必须是 4 维 (B, N_kv, S, D),但得到 {kv_tensor.shape}")
    if num_kv_heads == 0 or num_kv_heads > num_heads:
        raise ValueError("broadcast_kv: num_kv_heads 必须大于 0 且不超过 num_heads。")
    if num_heads % num_kv_heads != 0:
        raise ValueError(f"broadcast_kv: num_heads({num_heads}) 必须能被 num_kv_heads({num_kv_heads}) 整除。")

    group_size = num_heads // num_kv_heads
    batch, _, seq_len, head_dim = kv_tensor.shape
    expanded = torch.zeros([batch, num_heads, seq_len, head_dim]).to(dtype)
    for head_idx in range(num_heads):
        expanded[:, head_idx] = kv_tensor[:, head_idx // group_size]
    return expanded
74
+
75
+
76
def calculate_qk(q, k, attn_mask, pse, scalar_value):
    """Compute the (optionally biased and masked) scaled attention scores.

    qk = (q @ k^T [+ pse]) * scalar_value; positions where ``attn_mask`` is
    truthy then get a large negative bias so they vanish after softmax.

    Args:
        q, k: 4-D tensors laid out as (B, N, S, D).
        attn_mask: optional mask tensor; truthy entries are masked out.
        pse: optional positional bias added to the raw scores.
        scalar_value: scaling factor (typically 1/sqrt(head_dim)).

    Returns:
        The attention score tensor.

    Raises:
        ValueError: on rank or head_dim mismatches.
    """
    if q.dim() < 4 or k.dim() < 4:
        raise ValueError(f"calculate_qk: q,k 必须至少 4 维,q={q.dim()},k={k.dim()}")
    if q.size(-1) != k.size(-1):
        raise ValueError(f"calculate_qk: q.head_dim({q.size(-1)}) != k.head_dim({k.size(-1)})")
    if k.dim() != 4:
        raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")
    # BUGFIX: the original had an unreachable `if k.dim() == 3: k = k.unsqueeze(1)`
    # branch here — it sat after the strict 4-D check above, so it could never
    # execute. Removed as dead code.

    scores = torch.matmul(q, k.permute(0, 1, 3, 2))
    if pse is not None and len(pse.shape) != 0:
        scores = scores + pse
    qk = scores.mul(scalar_value)
    if attn_mask is None or len(attn_mask.shape) == 0:
        return qk
    # -40000.0 approximates -inf while staying finite in low-precision dtypes.
    return qk + attn_mask.bool() * (-40000.0)
99
+
100
+
101
def fusion_attention_forward(forward_params):
    """CPU reference forward of the fused attention operator.

    Computes softmax(Q·K^T·scale [+ pse, mask]), applies the optional dropout
    mask with keep_prob rescaling, and multiplies by V.

    Args:
        forward_params: an FaForwardParams-like object.

    Returns:
        (attention output, softmax row max, softmax row sum).

    Raises:
        ValueError: if keep_prob is 0 (would divide by zero).
    """
    params = forward_params
    # Reject keep_prob == 0 up front to avoid a division by zero below.
    if params.keep_prob == 0:
        raise ValueError("fusion_attention_forward: keep_prob 不能为 0,避免除零错误。")

    scores = calculate_qk(params.q, params.k, params.attn_mask, params.pse, params.scalar_value)
    softmax_res, softmax_max, softmax_sum = softmax_forward(scores)

    has_drop_mask = params.drop_mask is not None and len(params.drop_mask.shape) != 0
    attn_probs = softmax_res * params.drop_mask * (1.0 / params.keep_prob) if has_drop_mask else softmax_res
    return torch.matmul(attn_probs, params.v), softmax_max, softmax_sum
123
+
124
+
125
def fusion_attention_backward(backward_params):
    """CPU reference backward of the fused attention operator.

    Given the upstream gradient dx and the saved forward softmax result,
    recomputes (dq, dk, dv) with the same dropout mask / keep_prob scaling
    as the forward pass.

    Raises:
        ValueError: if keep_prob is 0 (would divide by zero).
    """
    p = backward_params
    # Reject keep_prob == 0 up front to avoid a division by zero below.
    if p.keep_prob == 0:
        raise ValueError("fusion_attention_backward: keep_prob 不能为 0,避免除零错误。")

    dp = torch.matmul(p.dx, p.v.permute(0, 1, 3, 2))
    has_drop_mask = p.drop_mask is not None and len(p.drop_mask.shape) != 0
    if has_drop_mask:
        probs_t = p.softmax_res.mul(p.drop_mask).mul(1.0 / p.keep_prob).permute(0, 1, 3, 2)
        dp_scaled = dp * p.drop_mask * (1.0 / p.keep_prob)
    else:
        probs_t = p.softmax_res.permute(0, 1, 3, 2)
        dp_scaled = dp
    dv = torch.matmul(probs_t, p.dx)
    grad_scores = softmax_grad(dp_scaled, p.softmax_res) * p.scalar_value
    dq = torch.matmul(grad_scores, p.k)
    dk = torch.matmul(grad_scores.permute(0, 1, 3, 2), p.q)
    return dq, dk, dv
152
+
153
+
154
def parse_bsnd_args(query, key, head_num, input_layout):
    """Derive (b, s1, s2, n1, n2, d, h1, h2, dtype) from q/k shapes and layout.

    b: batch, s1/s2: query/key sequence lengths, n1/n2: query/key head
    counts, d: head_dim, h1/h2: hidden sizes (n*d). For the BSND/BNSD
    layouts the head counts are read from the shapes and ``head_num`` is
    ignored.

    Raises:
        ValueError: for an unsupported layout, zero head_num/head_dim, or
            shapes inconsistent with the declared layout.
    """
    valid_layouts = ["BSH", "SBH", "BSND", "BNSD", "TND"]
    if not isinstance(input_layout, str) or input_layout not in valid_layouts:
        raise ValueError(f"Invalid input_layout arg which must be one of {valid_layouts}.")
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")

    n1 = head_num
    # head_num of 0 would lead to a division by zero when deriving head_dim.
    if n1 == 0:
        raise ValueError("parse_bsnd_args: head_num (n1) 不能为 0,避免除零错误。")

    try:
        if input_layout in ("BSH", "SBH"):
            # Hidden size is packed: derive d from h1 and the head count.
            if input_layout == "BSH":
                b, s1, h1 = query.shape
                _, s2, h2 = key.shape
            else:
                s1, b, h1 = query.shape
                s2, _, h2 = key.shape
            d = h1 // n1
            if d == 0:
                raise ValueError("parse_bsnd_args: 计算得到的 head_dim d 不能为 0。")
            n2 = h2 // d
        else:
            # BSND / BNSD carry head count and head_dim explicitly.
            if input_layout == "BSND":
                b, s1, n1, d = query.shape
                _, s2, n2, _ = key.shape
            else:
                b, n1, s1, d = query.shape
                _, n2, s2, _ = key.shape
            if d == 0:
                raise ValueError("parse_bsnd_args: head_dim d 不能为 0。")
            h1 = n1 * d
            h2 = n2 * d
    except Exception as e:
        # Wrap any unpacking/validation failure with the offending shapes.
        raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

    return b, s1, s2, n1, n2, d, h1, h2, query.dtype
203
+
204
+
205
def convert_from_bnsd(_input, input_layout):
    """
    Transform qkv from (B, N, S, D) to the given input_layout.

    B: batch_size, S: sequence_length, N: num_heads, D: head_dim

    Args:
        _input (torch.Tensor): tensor of shape (B, N, S, D)
        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"

    Returns:
        Tensor of shape (B, N, S, D), (B, S, N, D), (S, B, H) or (B, S, H).

    Raises:
        ValueError: for the not-yet-supported "TND" layout.
    """
    # Uses native permute/reshape instead of einops.rearrange so this helper
    # has no hard third-party dependency; results are identical.
    if input_layout == "BSH":
        # (B,N,S,D) => (B,S,N*D)
        b, n, s, d = _input.shape
        out = _input.permute(0, 2, 1, 3).reshape(b, s, n * d).contiguous()
    elif input_layout == "SBH":
        # (B,N,S,D) => (S,B,N*D)
        b, n, s, d = _input.shape
        out = _input.permute(2, 0, 1, 3).reshape(s, b, n * d).contiguous()
    elif input_layout == "BSND":
        # (B,N,S,D) => (B,S,N,D)
        out = _input.permute(0, 2, 1, 3).contiguous()
    elif input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    else:
        # "BNSD" (and any already-validated layout): pass through unchanged.
        out = _input
    return out
232
+
233
+
234
def convert_to_bnsd(_input, n, input_layout):
    """
    Transform qkv from input_layout to (B, N, S, D) and cast to GTYPE.

    B: batch_size, S: sequence_length, N: num_heads, D: head_dim

    Args:
        _input (torch.Tensor): tensor of shape (B,N,S,D), (B,S,N,D), (S,B,H) or (B,S,H)
        n (int): num_heads
        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"

    Returns:
        Tensor of shape (B, N, S, D), cast to the golden dtype GTYPE.

    Raises:
        ValueError: for "TND" or when the result is not 4-D.
    """
    # Uses native reshape/permute instead of einops.rearrange so this helper
    # has no hard third-party dependency; results are identical.
    if input_layout == "BSH":
        # (B,S,N*D) => (B,N,S,D)
        b, s, h = _input.shape
        out = _input.reshape(b, s, n, h // n).permute(0, 2, 1, 3)
    elif input_layout == "SBH":
        # (S,B,N*D) => (B,N,S,D)
        s, b, h = _input.shape
        out = _input.reshape(s, b, n, h // n).permute(1, 2, 0, 3)
    elif input_layout == "BSND":
        # (B,S,N,D) => (B,N,S,D)
        out = _input.permute(0, 2, 1, 3)
    elif input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    else:
        # "BNSD": already in target layout.
        out = _input
    if out.dim() != 4:
        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
    return out.to(GTYPE)
264
+
265
+
266
def generate_attn_mask(*args):
    """Rebuild the small-operator attention mask from fused-operator inputs.

    When sparse_mode is 2/3/4 the fused operator receives an optimized fixed
    2048x2048 upper-triangular template mask; this reverses that optimization
    back to the (s1, s2) mask the reference implementation needs:
        attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)

    Args (positional): sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken,
        next_tocken, dtype.

    Returns:
        The mask tensor cast to ``dtype`` (truthy entries are masked out).
    """
    sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
    shape = [s1, s2]

    if attn_mask is not None:
        # An incoming attn_mask is assumed to already be converted; only the
        # three sparse template cases need to be reversed to a full-size mask.
        if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
            logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}")

            if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048:
                if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)):
                    if sparse_mode == 2:
                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                    elif sparse_mode == 3:
                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
                    elif sparse_mode == 4:
                        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
                        attn_mask = attn_mask_u + attn_mask_l
                    logger.debug(f"反向转换attn_mask {attn_mask.shape}")
                    return attn_mask.to(dtype)

        return attn_mask.to(dtype)

    # BUGFIX: the original code had a second `if attn_mask is not None:` block
    # here recomputing `shape`; it was unreachable (the branch above always
    # returns when attn_mask is not None) and has been removed.

    # No mask supplied: synthesize one from sparse_mode.
    if sparse_mode == 0:
        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        attn_mask = attn_mask_u + attn_mask_l
    elif sparse_mode == 1:  # no sparse
        attn_mask = torch.from_numpy(np.zeros(shape))
    elif sparse_mode == 2:
        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
    elif sparse_mode == 3:
        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
    elif sparse_mode == 4:
        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        attn_mask = attn_mask_u + attn_mask_l
    # NOTE: sparse_mode == 5 never reaches here — it requires an attn_mask
    # input (BNSS or B1SS), which is handled by the early-return branch above.
    return attn_mask.to(dtype)
323
+
324
+
325
def generate_kv(key, value, n1, n2):
    """Adapt KV tensors when query/key head counts differ (grouped-query case)."""
    if n1 == n2:
        return key, value
    # Head counts differ: replicate each KV head across its query-head group.
    return (broadcast_kv(n1, n2, key, key.dtype),
            broadcast_kv(n1, n2, value, value.dtype))
334
+
335
+
336
def rebuid_softmax_by_qkv(q, k, attn_mask, pse, scalar_value):
    """Recompute the original softmax result directly from Q and K.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
    """
    logger.info("Using QKV to rebuild original softmax")
    scores = calculate_qk(q, k, attn_mask, pse, scalar_value)
    return softmax_forward(scores)[0]
345
+
346
+
347
def rebuild_softmax_by_max_sum(softmax_params):
    """Rebuild the softmax result from saved softmax_max/softmax_sum stats.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i

    Raises:
        ValueError: when softmax_max has an empty last dimension.
    """
    p = softmax_params
    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")

    qk = calculate_qk(p.q, p.k, p.attn_mask, p.pse, p.scalar_value)
    if p.softmax_max.shape[-1] == 0:
        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {p.softmax_max.shape}")
    # The saved statistics keep a reduced last dim; tile them back to qk's width.
    repeats = qk.shape[-1] // p.softmax_max.shape[-1]
    tiled_max = p.softmax_max.repeat(1, 1, 1, repeats)
    tiled_sum = p.softmax_sum.repeat(1, 1, 1, repeats)
    return torch.exp(qk.sub(tiled_max)).div(tiled_sum)
368
+
369
+
370
def get_head_num(*args, **kwargs):
    """Extract head_num from npu_fusion_attention's kwargs or positional args."""
    head_num = kwargs.get("head_num")
    if head_num:
        return head_num
    if len(args) >= 4:
        # Positional layout: (query, key, value, head_num, ...)
        return args[3]
    raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
378
+
379
+
380
def get_input_layout(*args, **kwargs):
    """Extract input_layout from npu_fusion_attention's kwargs or positional args."""
    input_layout = kwargs.get("input_layout")
    if input_layout:
        return input_layout
    if len(args) >= 5:
        # Positional layout: (query, key, value, head_num, input_layout, ...)
        return args[4]
    raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
388
+
389
+
390
def npu_fusion_attention_forward_patch(*args, **kwargs):
    """Normalize npu_fusion_attention forward args into (args, dims, kwargs).

    Parses the q/k shapes, validates the grouped-head relation (n1 divisible
    by n2), and collects the keyword arguments the CPU reference forward
    needs, with defaults filled in.

    Raises:
        RuntimeError: if fewer than two positional args are given.
        ValueError: if n2 is zero or n1/n2 are inconsistent.
    """
    if len(args) < 2:
        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")

    # Positional layout: query, key, value, head_num, input_layout, ...
    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(
        args[0], args[1], get_head_num(*args, **kwargs), get_input_layout(*args, **kwargs))

    sparse = kwargs.get('sparse_mode', 0)
    if n1 == n2 and s1 == s2:
        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {sparse}")
    else:
        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {sparse}")
    if n2 == 0:
        raise ValueError("n2 不能为 0,避免除零错误。")
    if n1 % n2 != 0 or n1 < n2:
        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")

    dims_kwargs = {
        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
        "d": d, "h1": h1, "h2": h2, "dtype": dtype
    }
    # d is guaranteed non-zero by parse_bsnd_args, so 1/sqrt(d) is safe.
    new_kwargs = {
        "keep_prob": 1,  # a caller-provided keep_prob of 0 is rejected later in fusion_attention_forward
        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
        "sparse_mode": kwargs.get("sparse_mode", 0),
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
        "next_tockens": kwargs.get("next_tockens", 2147483647),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "attn_mask": kwargs.get("attn_mask")
    }

    return args, dims_kwargs, new_kwargs
426
+
427
+
428
def npu_fusion_attention_backward_patch(*args, **kwargs):
    """Normalize the arguments of an npu_fusion_attention_grad call.

    Returns:
        tuple: ``(args, dims_kwargs, new_kwargs)`` — the original positional
        args, a dict of parsed tensor dimensions, and a dict of keyword
        arguments with defaults filled in.

    Raises:
        ValueError: wrong positional-argument count, n2 is zero, or n1 is
            not a positive multiple of n2.
    """
    if len(args) != 6:
        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
    # d is guaranteed non-zero by parse_bsnd_args, so 1 / (d ** 0.5) below is safe.
    if n1 == n2 and s1 == s2:
        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
    if n2 == 0:
        raise ValueError("n2 不能为 0,避免除零错误。")
    if not (n1 % n2 == 0 and n1 >= n2):
        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")

    dims_kwargs = {
        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
        "d": d, "h1": h1, "h2": h2, "dtype": dtype
    }

    new_kwargs = {
        # keep_prob == 0 is intercepted inside fusion_attention_backward.
        "keep_prob": 1,
        # Bug fix: this value was stored under "scalar_value_value", but the
        # consumer (FlashAttentionScore.backward) reads "scalar_value", so the
        # scale was silently None. Store it under "scalar_value"; the legacy
        # "scalar_value_value" kwarg is still honored for compatibility.
        "scalar_value": kwargs.get("scalar_value", kwargs.get("scalar_value_value", 1 / (d ** 0.5))),
        "sparse_mode": kwargs.get("sparse_mode", 0),
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
        "next_tockens": kwargs.get("next_tockens", 2147483647),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "softmax_max": kwargs.get("softmax_max"),
        "softmax_sum": kwargs.get("softmax_sum"),
        "softmax_in": kwargs.get("softmax_in"),
        "attention_in": kwargs.get("attention_in"),
        "seed": kwargs.get("seed", 0),
        "offset": kwargs.get("offset", 0),
        "numels": kwargs.get("numels", 0),
        "attn_mask": kwargs.get("attn_mask")
    }

    return args, dims_kwargs, new_kwargs
468
+
469
+
470
class FlashAttentionScore(nn.Module):
    """CPU reference ("golden") implementation of npu_fusion_attention.

    forward() recomputes attention from the same arguments the NPU fused
    kernel receives and returns CPU tensors for comparison; backward()
    recomputes (dq, dk, dv) for the gradient comparison.
    """

    def __init__(self):
        super(FlashAttentionScore, self).__init__()
        # No trainable parameters: this module only re-derives reference outputs.

    def forward(self, *inputs, **kwargs):
        """Recompute the fused-attention forward pass on CPU.

        Returns:
            tuple: (out_golden, softmax_max, softmax_sum), all moved to CPU;
            the softmax statistics are repeated 8x along the last axis.
        """
        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs)
        query, key, value = new_args[0], new_args[1], new_args[2]

        input_layout = get_input_layout(*inputs, **kwargs)

        n1 = dims_kwargs.get("n1")
        n2 = dims_kwargs.get("n2")
        s1 = dims_kwargs.get("s1")
        s2 = dims_kwargs.get("s2")
        b = dims_kwargs.get("b")
        dtype = dims_kwargs.get("dtype")
        attn_mask = new_kwargs.get("attn_mask")
        keep_prob = new_kwargs.get("keep_prob")
        sparse_mode = new_kwargs.get("sparse_mode")
        pre_tockens = new_kwargs.get("pre_tockens")
        # Bug fix: the dict key is "next_tockens" (sic, as built by the patch
        # function); looking up "next_tokens" always returned None.
        next_tockens = new_kwargs.get("next_tockens")
        # Bug fix: the patch function stores the bias under "pse";
        # "real_shift" was never a key of new_kwargs, so pse was always None.
        pse = new_kwargs.get("pse")
        scalar_value = new_kwargs.get("scalar_value")

        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]

        attn_mask = generate_attn_mask(*args_temp)
        query = convert_to_bnsd(query, n1, input_layout)
        key = convert_to_bnsd(key, n2, input_layout)
        value = convert_to_bnsd(value, n2, input_layout)

        forward_params = FaForwardParams(
            q=query,
            k=key,
            v=value,
            drop_mask=None,
            attn_mask=attn_mask,
            pse=pse,
            scalar_value=scalar_value,
            keep_prob=keep_prob
        )

        out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)

        # A 5-D result is folded back to 4-D by merging axes 1 and 2.
        if out_golden.dim() == 5:
            out_golden = out_golden.reshape(out_golden.size(0),
                                            out_golden.size(1) * out_golden.size(2),
                                            out_golden.size(3), out_golden.size(4))

        out_golden = convert_from_bnsd(out_golden, input_layout)

        # Move everything to CPU; the repeat(1, 1, 1, 8) presumably mirrors
        # the NPU kernel's softmax-statistics layout — TODO confirm.
        out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()

        return out_golden

    def backward(self, *inputs, **kwargs):
        """Recompute the fused-attention gradients (dq, dk, dv) on CPU."""
        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs)
        query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
        n1 = dims_kwargs.get("n1")
        n2 = dims_kwargs.get("n2")
        s1 = dims_kwargs.get("s1")
        s2 = dims_kwargs.get("s2")
        b = dims_kwargs.get("b")
        dtype = dims_kwargs.get("dtype")
        attn_mask = new_kwargs.get("attn_mask")
        keep_prob = new_kwargs.get("keep_prob")
        sparse_mode = new_kwargs.get("sparse_mode")
        pre_tockens = new_kwargs.get("pre_tockens")
        next_tockens = new_kwargs.get("next_tockens")
        pse = new_kwargs.get("pse")
        softmax_max = new_kwargs.get("softmax_max")
        softmax_sum = new_kwargs.get("softmax_sum")
        # Bug fix: the backward patch historically stored the scale under
        # "scalar_value_value"; accept either key so the scale is no longer
        # silently None regardless of which patch version produced new_kwargs.
        scalar_value = new_kwargs.get("scalar_value", new_kwargs.get("scalar_value_value"))

        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
        attn_mask = generate_attn_mask(*args_temp)

        query = convert_to_bnsd(query, n1, input_layout)
        dx = convert_to_bnsd(dx, n1, input_layout)
        key = convert_to_bnsd(key, n2, input_layout)
        value = convert_to_bnsd(value, n2, input_layout)

        k_new, v_new = generate_kv(key, value, n1, n2)

        # Either rebuild softmax from Q/K directly, or from the saved
        # max/sum statistics, depending on the module-level build mode.
        if SOFTMAX_BUILD_MODE == "QKV":
            softmax_res = rebuid_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value)
        else:
            softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum)
            softmax_res = rebuild_softmax_by_max_sum(softmax_params)

        backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob)
        dq, dk, dv = fusion_attention_backward(backward_params)

        # Fold any 5-D gradients back to 4-D by merging axes 1 and 2,
        # mirroring the reshape in forward().
        if dq.dim() == 5:
            dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
        if dk.dim() == 5:
            dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
        if dv.dim() == 5:
            dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))

        dq = convert_from_bnsd(dq, input_layout)
        dk = convert_from_bnsd(dk, input_layout)
        dv = convert_from_bnsd(dv, input_layout)

        return dq.cpu(), dk.cpu(), dv.cpu()
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore
17
+
18
+
19
class FusionOperator:
    """
    Parent class for all fusion operators: defines the common interface
    and holds the registered operator instances.
    """

    def __init__(self):
        # Slot for the FlashAttentionScore operator; populated by
        # _register_operators() below.
        self.flash_attention_score = None
        self._register_operators()

    def __getattr__(self, name):
        """Raise a clear error for unknown operator names.

        Bug fix: __getattr__ is only invoked after normal attribute lookup
        has already failed, so the attribute is definitely absent here. The
        previous implementation called hasattr(self, name), which re-entered
        __getattr__ and recursed until RecursionError; raising directly
        restores the intended AttributeError.
        """
        raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")

    def _register_operators(self):
        """Register operators so they can be called as fusion.xxx."""
        self.flash_attention_score = FlashAttentionScore()
39
+
40
+
41
# Module-level singleton registry: callers access operators through
# attributes such as `fusion.flash_attention_score`.
fusion = FusionOperator()
@@ -39,6 +39,8 @@ def add_api_accuracy_checker_argument(parser):
39
39
  help="<optional> The ut task result out path.")
40
40
  parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
41
41
  help="<optional> the exit csv for continue")
42
+ parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
43
+ help="<optional> Save compare failed api output.", required=False)
42
44
 
43
45
 
44
46
  def multi_add_api_accuracy_checker_argument(parser):
@@ -49,6 +51,8 @@ def multi_add_api_accuracy_checker_argument(parser):
49
51
  help="<optional> The ut task result out path.")
50
52
  parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
51
53
  help="<optional> the exit csv for continue")
54
+ parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
55
+ help="<optional> Save compare failed api output.", required=False)
52
56
  #以下属于多线程参数
53
57
  parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
54
58
  help="<optional> set device id to run ut, must be unique and in range 0-7",
@@ -16,12 +16,13 @@
16
16
  import os
17
17
  import csv
18
18
 
19
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
19
+ from msprobe.core.common.const import Const, CompareConst
20
20
  from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
21
21
  from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
22
22
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
23
23
  from msprobe.core.common.file_utils import check_file_or_directory_path
24
24
  from msprobe.mindspore.common.log import logger
25
+ from msprobe.mindspore.common.const import MsCompareConst
25
26
 
26
27
 
27
28
  class ResultCsvEntry:
@@ -187,7 +188,7 @@ class DataManager:
187
188
 
188
189
  def record_exception_skip(self, api_name, forward_or_backward, err_msg):
189
190
  '''
190
- record exception_skip infomation into self.record_exception_skip.
191
+ record exception_skip information into self.record_exception_skip.
191
192
  self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
192
193
  string in key is api_name, string in value is err_msg
193
194
  '''
@@ -269,7 +270,7 @@ class DataManager:
269
270
  entry.backward_pass_status,
270
271
  overall_err_msg
271
272
  ]
272
- # change row if this api has excption_skip infomation
273
+ # change row if this api has exception_skip information
273
274
  if api_name in self.results_exception_skip:
274
275
  if self.results_exception_skip[api_name][Const.FORWARD] is not None:
275
276
  row[1] = CompareConst.SKIP
@@ -0,0 +1,9 @@
1
+ {
2
+ "dump_json_path": "./dump.json",
3
+ "api_name": "Mint.split.1",
4
+ "extract_api_path": "Mint.split.1.json",
5
+ "propagation": "backward",
6
+ "data_mode": "random_data",
7
+ "random_seed": 1234,
8
+ "iter_times": 1
9
+ }