mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,421 @@
1
+ import torch
2
+ import numpy as np
3
+ from einops import rearrange
4
+
5
+ from msprobe.pytorch.common.utils import logger
6
+
7
+ gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
8
+ softmax_build_mode = "QKV" # "MAX_SUM"
9
+
10
+ """
11
+ # 前向函数声明对比
12
+ 标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
13
+ 融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
14
+ atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
15
+ next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
16
+ gen_mask_parallel=True, sync=False
17
+
18
+ # 反向函数声明对比
19
+ 标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
20
+ 融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
21
+ atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
22
+ attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
23
+ next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
24
+ numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
25
+ """
26
+
27
+
28
def softmax_forward(x):
    """Numerically stable softmax along the last axis.

    Returns (softmax_result, row_max, row_sum) so the max/sum statistics
    can be reused later to rebuild the softmax.
    """
    row_max = torch.max(x, dim=-1, keepdims=True)[0]
    shifted = torch.exp(x.sub(row_max))
    row_sum = shifted.sum(dim=-1, keepdims=True)
    return shifted.div(row_sum), row_max, row_sum
35
+
36
+
37
def softmax_grad(dp, softmax_res):
    """Backward of softmax: (dp - sum(dp * y, -1)) * y, with y = softmax_res."""
    weighted = dp * softmax_res
    row_total = weighted.sum(dim=-1, keepdims=True)
    return (dp - row_total) * softmax_res
43
+
44
+
45
def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
    """Broadcast a grouped KV tensor (B, num_kv_heads, S, D) up to
    (B, num_heads, S, D) by repeating each KV head num_heads // num_kv_heads
    times, so grouped-query attention can be computed with the full head count.

    Raises:
        ValueError: if num_kv_heads is zero, exceeds num_heads, or does not
            evenly divide num_heads.
    """
    # Bug fix: the original guard raised when `num_kv_heads < num_heads`,
    # which is exactly the valid grouped-KV case this function exists for
    # (callers only invoke it when N1 != N2 and have enforced N1 >= N2).
    # The invalid inputs are: zero, more KV heads than query heads, or a
    # non-divisible head ratio.
    if num_kv_heads == 0 or num_kv_heads > num_heads or num_heads % num_kv_heads != 0:
        raise ValueError("num_kv_heads must be non-zero, not exceed num_heads, and evenly divide num_heads.")

    factor = num_heads // num_kv_heads
    kv_shape = kv_tensor.shape
    B = kv_shape[0]
    S = kv_shape[2]
    D = kv_shape[3]
    kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
    for i in range(num_heads):
        j = i // factor  # source KV head for broadcast head i
        kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
    return kv_res
59
+
60
+
61
def calculate_qk(q, k, atten_mask, pse, scale):
    """Scaled Q·K^T with optional positional bias (pse) and additive mask.

    A non-empty atten_mask adds a large negative value (-40000.0) to the
    masked positions so they vanish after softmax.
    """
    k_t = k.permute(0, 1, 3, 2)
    if pse is None or len(pse.shape) == 0:
        qk = torch.matmul(q, k_t).mul(scale)
    else:
        qk = (torch.matmul(q, k_t) + pse).mul(scale)
    if atten_mask is None or len(atten_mask.shape) == 0:
        return qk
    return qk + atten_mask.bool() * (-40000.0)  # -10000
71
+
72
+
73
def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
    """Reference (golden) forward of fused attention.

    Returns (attention_out, softmax_max, softmax_sum); the statistics are
    the per-row max/sum captured during the stable softmax.
    """
    scores = calculate_qk(q, k, atten_mask, pse, scale)
    softmax_res, softmax_max, softmax_sum = softmax_forward(scores)
    if drop_mask is None or len(drop_mask.shape) == 0:
        dropped = softmax_res
    else:
        # Inverted-dropout scaling keeps the expectation unchanged.
        dropped = softmax_res * drop_mask * (1.0 / keep_prob)
    return torch.matmul(dropped, v), softmax_max, softmax_sum
82
+
83
+
84
def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
    """Reference (golden) backward of fused attention; returns (dq, dk, dv)."""
    dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
    if drop_mask is None or len(drop_mask.shape) == 0:
        drop_res = softmax_res.permute(0, 1, 3, 2)
        dp_drop = dp
    else:
        # Re-apply inverted dropout on both the saved softmax and its grad.
        drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
        dp_drop = dp * drop_mask * (1.0 / keep_prob)
    dv = torch.matmul(drop_res, dx)
    scaled_grad = softmax_grad(dp_drop, softmax_res) * scale
    dq = torch.matmul(scaled_grad, k)
    dk = torch.matmul(scaled_grad.permute(0, 1, 3, 2), q)
    return dq, dk, dv
97
+
98
+
99
def parse_bsnd_args(query, key, head_num, input_layout):
    """Infer attention dimensions from the query/key shapes for a layout.

    Returns (B, S1, S2, N1, N2, D, H1, H2, DTYPE) where S1/N1/H1 describe
    the query and S2/N2/H2 the key; D is the per-head dim.

    Raises ValueError for an unknown layout, the unsupported "TND" layout,
    shapes that do not unpack, or a zero head dimension.
    """
    supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
    if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
        raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")

    N1 = head_num
    try:
        if input_layout == "BSH":
            B, S1, H1 = query.shape
            _, S2, H2 = key.shape
            D = H1 // N1
            N2 = H2 // D
        elif input_layout == "SBH":
            S1, B, H1 = query.shape
            S2, _, H2 = key.shape
            D = H1 // N1
            N2 = H2 // D
        elif input_layout == "BSND":
            B, S1, N1, D = query.shape
            _, S2, N2, _ = key.shape
            H1 = N1 * D
            H2 = N2 * D
        else:  # "BNSD"
            B, N1, S1, D = query.shape
            _, N2, S2, _ = key.shape
            H1 = N1 * D
            H2 = N2 * D
    except Exception as e:
        raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

    if D == 0:
        raise ValueError(f"Value D must be non-zero.")
    return B, S1, S2, N1, N2, D, H1, H2, query.dtype
136
+
137
+
138
def convert_from_bnsd(_input, input_layout):
    """Convert a (B, N, S, D) tensor back into the requested layout.

    "BNSD" (and any other unmatched layout) is returned unchanged; "TND"
    is rejected as unsupported.
    """
    if input_layout == "BSH":
        # (B,N,S,D) => (B,S,N*D)
        return rearrange(_input, 'b n s d -> b s (n d)').contiguous()
    if input_layout == "SBH":
        # (B,N,S,D) => (S,B,N*D)
        return rearrange(_input, 'b n s d -> s b (n d)').contiguous()
    if input_layout == "BSND":
        # (B,N,S,D) => (B,S,N,D)
        return rearrange(_input, 'b n s d -> b s n d').contiguous()
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    return _input
153
+
154
+
155
def convert_to_bnsd(_input, n, input_layout):
    """Convert q/k/v from `input_layout` into (B, N, S, D) and cast to the
    module-level golden dtype `gtype`.

    Raises ValueError for the unsupported "TND" layout or when the result
    is not 4-dimensional.
    """
    # The default "BNSD" layout needs no rearrangement.
    if input_layout == "BSH":
        # (B,S,N*D)=>(B,N,S,D)
        out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
    elif input_layout == "SBH":
        # (S,B,N*D)=>(B,N,S,D)
        out = rearrange(_input, 's b (n d) -> b n s d', n=n)
    elif input_layout == "BSND":
        # (B,S,N,D)=>(B,N,S,D)
        out = rearrange(_input, 'b s n d -> b n s d', n=n)
    elif input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    else:
        out = _input
    if out.dim() != 4:
        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
    # Golden computation runs in high precision (gtype, module-level setting).
    return out.to(gtype)
173
+
174
+
175
def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
    """
    Reconstruct the dense attention mask the golden implementation needs.

    When sparse_mode is 2/3/4 the fused operator replaces the mask with a
    fixed 2048x2048 upper-triangular matrix as an optimization; that has to
    be reversed here back to the basic dense form:
    ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)

    When no mask is supplied, one is synthesized from sparse_mode and the
    pre/next token windows.
    """
    shape = [S1, S2]

    if atten_mask is not None:
        # When FA already received an atten_mask it is taken to be the
        # converted matrix; the three sparse cases (2/3/4) must be reversed.
        if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
            logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")

            if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
                if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
                    if sparse_mode == 2:
                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                    elif sparse_mode == 3:
                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
                    elif sparse_mode == 4:
                        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
                        atten_mask = atten_mask_u + atten_mask_l
                    logger.debug(f"反向转换atten_mask {atten_mask.shape}")
                    return atten_mask.to(dtype)

        return atten_mask.to(dtype)

    if atten_mask is not None:
        # NOTE(review): this branch appears unreachable — the block above
        # always returns when atten_mask is not None; verify the intended
        # nesting against upstream before relying on it.
        if atten_mask.dim() == 2:
            if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
                raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
            shape = [S1, S2]
        elif atten_mask.dim() == 4:
            if atten_mask.shape[1] == 1:
                shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
            else:
                shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]

    # No mask supplied: synthesize one from sparse_mode and the token windows.
    if sparse_mode == 0:
        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        atten_mask = atten_mask_u + atten_mask_l
    elif sparse_mode == 1:  # no sparse
        atten_mask = torch.from_numpy(np.zeros(shape))
    elif sparse_mode == 2:
        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
    elif sparse_mode == 3:
        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
    elif sparse_mode == 4:
        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        atten_mask = atten_mask_u + atten_mask_l
    # Note: sparse_mode=5 cannot occur here — it requires atten_mask to be
    # provided (BNSS or B1SS), so FA's input mask is already correct and is
    # returned as-is by the branch above.
    return atten_mask.to(dtype)
230
+
231
+
232
def generate_kv(key, value, N1, N2):
    """Expand grouped KV heads (N2) to match the query head count (N1).

    When the head counts already match, the inputs are returned untouched.
    """
    if N1 == N2:
        return key, value
    expanded_key = broadcast_kv(N1, N2, key, key.dtype)
    expanded_value = broadcast_kv(N1, N2, value, value.dtype)
    return expanded_key, expanded_value
241
+
242
+
243
def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
    """Rebuild the forward softmax result directly from Q and K.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))

    Note: the misspelled name ("rebuid") is kept for interface compatibility.
    """
    logger.info("Using QKV to rebuild original softmax")
    scores = calculate_qk(q, k, atten_mask, pse, scale)
    rebuilt, _, _ = softmax_forward(scores)
    return rebuilt
252
+
253
+
254
def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
    """Rebuild the forward softmax from the saved row max/sum statistics.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
    """
    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
    qk = calculate_qk(q, k, atten_mask, pse, scale)
    if softmax_max.shape[-1] == 0:
        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
    repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
    stats_max = softmax_max.repeat(1, 1, 1, repeat_dim)
    stats_sum = softmax_sum.repeat(1, 1, 1, repeat_dim)
    return torch.exp(qk.sub(stats_max)).div(stats_sum)
267
+
268
+
269
def npu_fusion_attention_forward_patch(*args, **kwargs):
    """Validate npu_fusion_attention positional args and normalize kwargs.

    Expects args = (query, key, value, head_num, input_layout). Returns
    (args, dims_kwargs, new_kwargs): the parsed dimensions and the keyword
    arguments filled with the operator's defaults.
    """
    # query, key, value, head_num, input_layout
    if len(args) != 5:
        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")

    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4])
    if N1 == N2 and S1 == S2:
        logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    if N1 % N2 != 0 or N1 < N2:
        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")

    dim_names = ("B", "S1", "S2", "N1", "N2", "D", "H1", "H2", "DTYPE")
    dims_kwargs = dict(zip(dim_names, (B, S1, S2, N1, N2, D, H1, H2, DTYPE)))

    new_kwargs = {
        "keep_prob": 1,
        "scale": kwargs.get("scale", 1 / (D ** 0.5)),
        "sparse_mode": kwargs.get("sparse_mode", 0),
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
        "next_tockens": kwargs.get("next_tockens", 2147483647),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "atten_mask": kwargs.get("atten_mask"),
    }

    return args, dims_kwargs, new_kwargs
296
+
297
+
298
def npu_fusion_attention_backward_patch(*args, **kwargs):
    """Validate npu_fusion_attention_grad positional args and normalize kwargs.

    Expects args = (query, key, value, dy, head_num, input_layout). Returns
    (args, dims_kwargs, new_kwargs): the parsed dimensions and the keyword
    arguments filled with the operator's defaults.
    """
    if len(args) != 6:
        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
    if N1 == N2 and S1 == S2:
        logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    if N1 % N2 != 0 or N1 < N2:
        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")

    dim_names = ("B", "S1", "S2", "N1", "N2", "D", "H1", "H2", "DTYPE")
    dims_kwargs = dict(zip(dim_names, (B, S1, S2, N1, N2, D, H1, H2, DTYPE)))

    new_kwargs = {
        "keep_prob": 1,
        "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
        "sparse_mode": kwargs.get("sparse_mode", 0),
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
        "next_tockens": kwargs.get("next_tockens", 2147483647),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "softmax_max": kwargs.get("softmax_max"),
        "softmax_sum": kwargs.get("softmax_sum"),
        "softmax_in": kwargs.get("softmax_in"),
        "attention_in": kwargs.get("attention_in"),
        "seed": kwargs.get("seed", 0),
        "offset": kwargs.get("offset", 0),
        "numels": kwargs.get("numels", 0),
        "atten_mask": kwargs.get("atten_mask"),
    }

    return args, dims_kwargs, new_kwargs
331
+
332
+
333
def npu_fusion_attention(*args, **kwargs):
    """CPU golden implementation of the npu_fusion_attention forward.

    Parses args/kwargs via npu_fusion_attention_forward_patch, rebuilds the
    dense attention mask, normalizes q/k/v to (B, N, S, D), and runs the
    reference attention. Returns (attention_out, softmax_max, softmax_sum)
    on CPU; the max/sum statistics are repeated x8 along the last axis
    (presumably to match the NPU output layout — TODO confirm).
    """
    new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
    query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4]
    N1 = dims_kwargs.get("N1")
    N2 = dims_kwargs.get("N2")
    S1 = dims_kwargs.get("S1")
    S2 = dims_kwargs.get("S2")
    B = dims_kwargs.get("B")
    DTYPE = dims_kwargs.get("DTYPE")
    atten_mask = new_kwargs.get("atten_mask")
    keep_prob = new_kwargs.get("keep_prob")
    sparse_mode = new_kwargs.get("sparse_mode")
    pre_tockens = new_kwargs.get("pre_tockens")
    next_tockens = new_kwargs.get("next_tockens")
    pse = new_kwargs.get("pse")
    scale = new_kwargs.get("scale")

    # Rebuild the dense mask and bring q/k/v into the canonical BNSD layout.
    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
    query = convert_to_bnsd(query, N1, input_layout)
    key = convert_to_bnsd(key, N2, input_layout)
    value = convert_to_bnsd(value, N2, input_layout)
    # Broadcast grouped KV heads (N2) up to the query head count (N1).
    k_new, v_new = generate_kv(key, value, N1, N2)
    out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
                                                                    drop_mask=None, atten_mask=atten_mask,
                                                                    pse=pse, scale=scale,
                                                                    keep_prob=keep_prob)
    # Collapse a 5-D result back to 4-D by merging the two head-like axes.
    if out_golden.dim() == 5:
        out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
                                        out_golden.size(4))
    out_golden = convert_from_bnsd(out_golden, input_layout)

    return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
365
+
366
+
367
def npu_fusion_attention_grad(*args, **kwargs):
    """CPU golden implementation of the npu_fusion_attention backward.

    Expects positional args (query, key, value, dx, head_num, input_layout).
    Rebuilds the dense mask and the forward softmax, runs the reference
    backward, and returns (dq, dk, dv) on CPU in the caller's input_layout.
    """
    # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
    new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
    query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
    N1 = dims_kwargs.get("N1")
    N2 = dims_kwargs.get("N2")
    S1 = dims_kwargs.get("S1")
    S2 = dims_kwargs.get("S2")
    B = dims_kwargs.get("B")
    D = dims_kwargs.get("D")
    DTYPE = dims_kwargs.get("DTYPE")
    atten_mask = new_kwargs.get("atten_mask")
    keep_prob = new_kwargs.get("keep_prob")
    sparse_mode = new_kwargs.get("sparse_mode")
    pre_tockens = new_kwargs.get("pre_tockens")
    next_tockens = new_kwargs.get("next_tockens")
    pse = new_kwargs.get("pse")
    softmax_max = new_kwargs.get("softmax_max")
    softmax_sum = new_kwargs.get("softmax_sum")
    scale_value = new_kwargs.get("scale_value")

    # Rebuild the dense mask and bring all tensors into the BNSD layout.
    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
    query = convert_to_bnsd(query, N1, input_layout)
    dx = convert_to_bnsd(dx, N1, input_layout)
    key = convert_to_bnsd(key, N2, input_layout)
    value = convert_to_bnsd(value, N2, input_layout)
    k_new, v_new = generate_kv(key, value, N1, N2)

    # Rebuild the forward softmax either from Q/K directly or from the saved
    # max/sum statistics, depending on the module-level softmax_build_mode.
    if softmax_build_mode == "QKV":
        softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
    else:
        softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)

    dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)

    # Grouped heads (N1 != N2): sum the gradients of the broadcast KV heads
    # back onto the original N2 heads.
    if not (N1 == N2):
        if N2 == 0:
            raise ValueError("dims_kwargs.N2 must be non-zero.")
        G = int(N1 / N2)
        dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
        dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)

    # Collapse any 5-D gradients back to 4-D by merging the head-like axes.
    if dq.dim() == 5:
        dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
    if dk.dim() == 5:
        dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
    if dv.dim() == 5:
        dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))

    dq = convert_from_bnsd(dq, input_layout)
    dk = convert_from_bnsd(dk, input_layout)
    dv = convert_from_bnsd(dv, input_layout)

    return dq.cpu(), dk.cpu(), dv.cpu()
@@ -0,0 +1,15 @@
1
+ import torch
2
+
3
+
4
def npu_rms_norm(x, gamma, epsilon=1e-5):
    """CPU golden implementation of npu_rms_norm.

    Normalizes `x` by the reciprocal root-mean-square over the last axis and
    scales the result by `gamma`.

    Args:
        x: input tensor.
        gamma: scale tensor, broadcastable against `x`.
        epsilon: small constant added to the mean of squares for stability.

    Returns:
        Tuple `(result, rstd)` on CPU; `rstd` is cast to float32.
    """
    mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
    rstd = torch.rsqrt(mean_square + epsilon)
    normalized = x * rstd * gamma
    return normalized.cpu(), rstd.float().cpu()
8
+
9
+
10
def npu_rms_norm_backward(grad, x, gamma, rstd):
    """CPU golden backward for npu_rms_norm.

    Args:
        grad: upstream gradient w.r.t. the forward output.
        x: forward input tensor.
        gamma: forward scale tensor.
        rstd: reciprocal RMS saved from the forward pass.

    Returns:
        Tuple `(grad_x, grad_gamma)` on CPU.

    NOTE(review): grad_gamma is returned elementwise (no reduction over
    batch dims) — presumably the caller reduces it; confirm against usage.
    """
    # Mean of grad * normalized-x over the last axis (projection term).
    proj = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)
    dx = (grad * gamma - x * rstd * proj) * rstd
    dgamma = x * grad * rstd
    return dx.cpu(), dgamma.cpu()
15
+
@@ -0,0 +1,52 @@
1
+ import torch
2
+
3
+
4
def npu_rotary_mul(x, r1, r2):
    """CPU golden implementation of npu_rotary_mul (rotary embedding).

    Splits `x` into two halves along the last axis, builds the rotated
    tensor (-x2, x1), and combines: r1 * x + r2 * rotated.

    Returns the result on CPU.
    """
    first_half, second_half = torch.chunk(x, 2, -1)
    rotated = torch.cat((-second_half, first_half), dim=-1)
    result = r1 * x + r2 * rotated
    return result.cpu()
9
+
10
+
11
def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
    """CPU golden backward for npu_rotary_mul.

    Returns (dx, dr1, dr2) on CPU. dx is obtained via autograd on the
    recomputed forward formula; dr1/dr2 are accumulated manually by summing
    the per-element products over the axes on which r1/r2 have size 1
    (their broadcast axes).

    NOTE(review): mutates requires_grad on the caller's x, r1 and r2.
    NOTE(review): if none of the three broadcast layouts below matches
    (e.g. r1 has the same full shape as x), r1_grad/r2_grad are returned
    as zeros — confirm whether that case can occur for valid inputs.
    """
    x.requires_grad = True
    r1.requires_grad = True
    r2.requires_grad = True
    # golden: recompute the forward pass and let autograd produce x.grad
    x1, x2 = torch.chunk(x, 2, -1)
    x_new = torch.cat((-x2, x1), dim=-1)
    golden_tensor = r1 * x + r2 * x_new
    golden_tensor.backward(dy_tensor)
    r1_shape = r1.shape
    r1_grad = torch.zeros(r1_shape).type(torch.float32)
    r2_grad = torch.zeros(r1_shape).type(torch.float32)
    # fp32 rotated copy of x, used when accumulating the r2 gradient
    x1, x2 = torch.chunk(x.float(), 2, -1)
    x_new2 = torch.cat((-x2, x1), dim=-1)
    x_shape = x.shape
    h = x.float()
    grad = dy_tensor.float()
    # condition_1: r1 has size 1 on dims 0 and 2, matches x on dims 1 and 3.
    # NOTE(review): each `(A and B) or (A and C)` pair reduces to just
    # `r1_shape[d] == 1`; kept verbatim here.
    condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
                   ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
                   (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
    # condition_2: r1 has size 1 on dims 0 and 1, matches x on dims 2 and 3
    condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
                   (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
    # condition_3: r1 has size 1 on dims 1 and 2, matches x on dims 0 and 3
    condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
                   (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
    if condition_1:
        # sum the elementwise products over the broadcast axes 0 and 2
        for i in range(x_shape[0]):
            for j in range(x_shape[2]):
                r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
                r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
    elif condition_2:
        # sum the elementwise products over the broadcast axes 0 and 1
        for i in range(x_shape[0]):
            for j in range(x_shape[1]):
                r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
                r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
    elif condition_3:
        # sum the elementwise products over the broadcast axes 1 and 2
        for i in range(x_shape[1]):
            for j in range(x_shape[2]):
                r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
                r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
    return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+
4
def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):
    """CPU golden implementation of npu_scaled_masked_softmax.

    Scales `x`, fills masked positions with a large negative value, and
    applies a numerically stable softmax over the last axis in fp32.

    Args:
        x: input logits tensor.
        mask: boolean mask; True positions are suppressed. When
            fixed_triu_mask is set, only its shape/device are used.
        scale: scalar multiplier applied to `x` before masking.
        fixed_triu_mask: if True, replace `mask` with a strictly
            upper-triangular mask of the same shape.

    Returns:
        Softmax result cast back to x's dtype, on CPU.
    """
    if fixed_triu_mask:
        # Bug fix: torch.triu takes `diagonal=`, not numpy's `k=`;
        # `k=1` raised TypeError whenever fixed_triu_mask was True.
        mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
    dtype = x.dtype
    x = (x * scale).masked_fill(mask, value=-10000)
    # Subtract the row max for numerical stability before exponentiating.
    x = x - torch.max(x, dim=-1, keepdims=True)[0]
    x = torch.exp(x.float())
    y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
    return y.to(dtype).cpu()
13
+
14
+
15
def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
    """CPU golden backward for npu_scaled_masked_softmax.

    Computes the softmax backward pass in fp32:
        x_grad = (y_grad - sum(y_grad * y, dim=-1)) * y * scale
    then zeroes the masked positions.

    Args:
        y_grad: upstream gradient w.r.t. the softmax output.
        y: forward softmax output.
        mask: boolean mask; True positions get zero gradient. When
            fixed_triu_mask is set, only its shape/device are used.
        scale: scalar multiplier used in the forward pass.
        fixed_triu_mask: if True, rebuild the strictly upper-triangular mask.

    Returns:
        Gradient w.r.t. x, cast back to y_grad's dtype, on CPU.
    """
    if fixed_triu_mask:
        # Bug fix: torch.triu takes `diagonal=`, not numpy's `k=`;
        # `k=1` raised TypeError whenever fixed_triu_mask was True.
        mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
    dtype = y_grad.dtype
    y_grad = y_grad.float()
    y = y.float()
    x_grad = y_grad * y
    x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
    x_grad = x_grad * y
    x_grad = x_grad * scale
    x_grad = x_grad.masked_fill(mask, value=0)
    return x_grad.to(dtype).cpu()
@@ -0,0 +1,55 @@
1
+ import torch
2
+
3
+
4
def npu_swiglu(x, dim=-1):
    """CPU golden implementation of npu_swiglu.

    Splits `x` into (gate, value) halves along `dim` and returns
    SiLU(gate) * value.

    Args:
        x: input tensor; its size along `dim` must be even.
        dim: axis along which `x` is split.

    Returns:
        Result in x's dtype, on CPU.
    """
    dtype = x.dtype
    gate, value = torch.chunk(x, 2, dim=dim)
    if dtype == torch.float32:
        # silu(gate) * value, computed directly in fp32
        out = torch.sigmoid(gate * 1.0) * gate * value
    else:
        # Lower-precision dtypes: compute in fp32 and round the SiLU result
        # back through the input dtype first — presumably to mirror the NPU
        # kernel's intermediate rounding (TODO confirm).
        gate_f = gate.type(torch.float)
        value_f = value.type(torch.float)
        product = torch.nn.functional.silu(gate_f).type(dtype).type(torch.float32) * value_f
        out = product.type(dtype)
    return out.cpu()
18
+
19
+
20
def npu_swiglu_backward(grad, x, dim=-1):
    """CPU golden backward for npu_swiglu.

    Splits `x` into (gate, value) halves along `dim` and returns the
    concatenation of:
        d_gate  = value * swish'(gate) * grad
        d_value = grad * swish(gate)
    fp16/bf16 inputs are computed in fp32 with intermediate rounding back
    through the input dtype — presumably to mirror the NPU kernel's
    precision behavior (TODO confirm).
    """
    dtype = grad.dtype
    gate, value = torch.chunk(x, 2, dim=dim)

    if dtype == torch.float16:
        d_gate = torch.mul(
            torch.mul(value.type(torch.float32), swish_grad(1, gate.type(torch.float32))),
            grad.type(torch.float32)).type(torch.float16)
        d_value = torch.mul(grad.type(torch.float32),
                            swish(1, gate.type(torch.float32))).type(torch.float16)
        output = torch.cat((d_gate, d_value), dim)
    elif dtype == torch.bfloat16:
        gate_f = gate.type(torch.float)
        value_f = value.type(torch.float)
        grad_f = grad.type(torch.float)

        d_gate = torch.mul(grad_f, swish_grad(1.0, gate_f)).type(torch.bfloat16).type(
            torch.float32) * value_f
        d_value = swish(1.0, gate_f).type(torch.bfloat16).type(torch.float32) * grad_f
        combined = torch.cat((d_gate, d_value), dim=dim)
        output = combined.type(torch.bfloat16)
    else:
        d_gate = torch.mul(torch.mul(value, swish_grad(1.0, gate)), grad)
        d_value = torch.mul(grad, swish(1.0, gate))
        output = torch.cat((d_gate, d_value), dim)
    return output.cpu()
47
+
48
+
49
def swish_grad(beta, x):
    """Derivative of swish(beta, x) w.r.t. x: s + x * beta * s * (1 - s), where s = sigmoid(beta * x)."""
    s = torch.sigmoid(beta * x)
    return s + x * (1 - s) * s * beta
51
+
52
+
53
def swish(beta, x):
    """Swish activation: x * sigmoid(beta * x)."""
    return torch.sigmoid(beta * x) * x
55
+
@@ -1,5 +1,7 @@
1
1
  import json
2
+
2
3
  from msprobe.core.common.exceptions import ParseJsonException
4
+ from msprobe.core.common.file_check import FileOpen
3
5
 
4
6
 
5
7
  def parse_json_info_forward_backward(json_path):
@@ -11,7 +13,7 @@ def parse_json_info_forward_backward(json_path):
11
13
  api_name = '.'.join(name_struct[:-1])
12
14
  return api_name
13
15
 
14
- with open(json_path, 'r') as f:
16
+ with FileOpen(json_path, 'r') as f:
15
17
  dump_json = json.load(f)
16
18
 
17
19
  real_data_path = dump_json.get("dump_data_dir")