mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,15 +1,15 @@
1
- import torch
2
-
3
-
4
- def npu_rms_norm(x, gamma, epsilon=1e-5):
5
- rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon)
6
- res = x * rstd * gamma
7
- return res.cpu(), rstd.float().cpu()
8
-
9
-
10
- def npu_rms_norm_backward(grad, x, gamma, rstd):
11
- mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)
12
- grad_x = (grad * gamma - x * rstd * mean_gy) * rstd
13
- grad_gamma = x * grad * rstd
14
- return grad_x.cpu(), grad_gamma.cpu()
15
-
1
+ import torch
2
+
3
+
4
+ def npu_rms_norm(x, gamma, epsilon=1e-5):
5
+ rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon)
6
+ res = x * rstd * gamma
7
+ return res, rstd.float()
8
+
9
+
10
+ def npu_rms_norm_backward(grad, x, gamma, rstd):
11
+ mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)
12
+ grad_x = (grad * gamma - x * rstd * mean_gy) * rstd
13
+ grad_gamma = x * grad * rstd
14
+ return grad_x.cpu(), grad_gamma.cpu()
15
+
@@ -1,52 +1,52 @@
1
- import torch
2
-
3
-
4
- def npu_rotary_mul(x, r1, r2):
5
- x1, x2 = torch.chunk(x, 2, -1)
6
- x_new = torch.cat((-x2, x1), dim=-1)
7
- output = r1 * x + r2 * x_new
8
- return output.cpu()
9
-
10
-
11
- def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
12
- x.requires_grad = True
13
- r1.requires_grad = True
14
- r2.requires_grad = True
15
- # golden
16
- x1, x2 = torch.chunk(x, 2, -1)
17
- x_new = torch.cat((-x2, x1), dim=-1)
18
- golden_tensor = r1 * x + r2 * x_new
19
- golden_tensor.backward(dy_tensor)
20
- r1_shape = r1.shape
21
- r1_grad = torch.zeros(r1_shape).type(torch.float32)
22
- r2_grad = torch.zeros(r1_shape).type(torch.float32)
23
- x1, x2 = torch.chunk(x.float(), 2, -1)
24
- x_new2 = torch.cat((-x2, x1), dim=-1)
25
- x_shape = x.shape
26
- h = x.float()
27
- grad = dy_tensor.float()
28
- condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
29
- ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
30
- (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
31
- condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
32
- ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
33
- (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
34
- condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
35
- ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
36
- (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
37
- if condition_1:
38
- for i in range(x_shape[0]):
39
- for j in range(x_shape[2]):
40
- r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
41
- r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
42
- elif condition_2:
43
- for i in range(x_shape[0]):
44
- for j in range(x_shape[1]):
45
- r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
46
- r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
47
- elif condition_3:
48
- for i in range(x_shape[1]):
49
- for j in range(x_shape[2]):
50
- r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
51
- r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
52
- return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
1
+ import torch
2
+
3
+
4
+ def npu_rotary_mul(x, r1, r2):
5
+ x1, x2 = torch.chunk(x, 2, -1)
6
+ x_new = torch.cat((-x2, x1), dim=-1)
7
+ output = r1 * x + r2 * x_new
8
+ return output
9
+
10
+
11
+ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
12
+ x.requires_grad = True
13
+ r1.requires_grad = True
14
+ r2.requires_grad = True
15
+ # golden
16
+ x1, x2 = torch.chunk(x, 2, -1)
17
+ x_new = torch.cat((-x2, x1), dim=-1)
18
+ golden_tensor = r1 * x + r2 * x_new
19
+ golden_tensor.backward(dy_tensor)
20
+ r1_shape = r1.shape
21
+ r1_grad = torch.zeros(r1_shape).type(torch.float32)
22
+ r2_grad = torch.zeros(r1_shape).type(torch.float32)
23
+ x1, x2 = torch.chunk(x.float(), 2, -1)
24
+ x_new2 = torch.cat((-x2, x1), dim=-1)
25
+ x_shape = x.shape
26
+ h = x.float()
27
+ grad = dy_tensor.float()
28
+ condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
29
+ ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
30
+ (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
31
+ condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
32
+ ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
33
+ (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
34
+ condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
35
+ ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
36
+ (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
37
+ if condition_1:
38
+ for i in range(x_shape[0]):
39
+ for j in range(x_shape[2]):
40
+ r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
41
+ r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
42
+ elif condition_2:
43
+ for i in range(x_shape[0]):
44
+ for j in range(x_shape[1]):
45
+ r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
46
+ r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
47
+ elif condition_3:
48
+ for i in range(x_shape[1]):
49
+ for j in range(x_shape[2]):
50
+ r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
51
+ r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
52
+ return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
@@ -1,26 +1,26 @@
1
- import torch
2
-
3
-
4
- def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):
5
- if fixed_triu_mask:
6
- mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device)
7
- dtype = x.dtype
8
- x = (x * scale).masked_fill(mask, value=-10000)
9
- x = x - torch.max(x, dim=-1, keepdims=True)[0]
10
- x = torch.exp(x.float())
11
- y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
12
- return y.to(dtype).cpu()
13
-
14
-
15
- def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
16
- if fixed_triu_mask:
17
- mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device)
18
- dtype = y_grad.dtype
19
- y_grad = y_grad.float()
20
- y = y.float()
21
- x_grad = y_grad * y
22
- x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
23
- x_grad = x_grad * y
24
- x_grad = x_grad * scale
25
- x_grad = x_grad.masked_fill(mask, value=0)
26
- return x_grad.to(dtype).cpu()
1
+ import torch
2
+
3
+
4
+ def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):
5
+ if fixed_triu_mask:
6
+ mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device)
7
+ dtype = x.dtype
8
+ x = (x * scale).masked_fill(mask, value=-10000)
9
+ x = x - torch.max(x, dim=-1, keepdims=True)[0]
10
+ x = torch.exp(x.float())
11
+ y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
12
+ return y.to(dtype)
13
+
14
+
15
+ def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
16
+ if fixed_triu_mask:
17
+ mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device)
18
+ dtype = y_grad.dtype
19
+ y_grad = y_grad.float()
20
+ y = y.float()
21
+ x_grad = y_grad * y
22
+ x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
23
+ x_grad = x_grad * y
24
+ x_grad = x_grad * scale
25
+ x_grad = x_grad.masked_fill(mask, value=0)
26
+ return x_grad.to(dtype).cpu()
@@ -1,55 +1,55 @@
1
- import torch
2
-
3
-
4
- def npu_swiglu(x, dim=-1):
5
- tensor_dtype = x.dtype
6
-
7
- inTensors = torch.chunk(x, 2, dim=dim)
8
- if tensor_dtype == torch.float32:
9
- tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
10
- output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
11
- else:
12
- tensor_self_float = inTensors[0].type(torch.float)
13
- tensor_other_float = inTensors[1].type(torch.float)
14
- tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
15
- torch.float32) * tensor_other_float
16
- output_data = tensor_out_float.type(tensor_dtype)
17
- return output_data.cpu()
18
-
19
-
20
- def npu_swiglu_backward(grad, x, dim=-1):
21
- tensor_dtype = grad.dtype
22
- in_tensors = torch.chunk(x, 2, dim=dim)
23
- tensor_grad_out = grad
24
-
25
- if tensor_dtype == torch.float16:
26
- tensor_out1 = torch.mul(
27
- torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
28
- tensor_grad_out.type(torch.float32)).type(torch.float16)
29
- tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
30
- swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)
31
- output = torch.cat((tensor_out1, tensor_out2), dim)
32
- elif tensor_dtype == torch.bfloat16:
33
- tensor_self_float = in_tensors[0].type(torch.float)
34
- tensor_other_float = in_tensors[1].type(torch.float)
35
- tensor_gradout_float = tensor_grad_out.type(torch.float)
36
-
37
- tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
38
- torch.float32) * tensor_other_float
39
- tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float
40
- tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
41
- output = tensor_out_float.type(torch.bfloat16)
42
- else:
43
- tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)
44
- tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))
45
- output = torch.cat((tensor_out1, tensor_out2), dim)
46
- return output.cpu()
47
-
48
-
49
- def swish_grad(beta, x):
50
- return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
51
-
52
-
53
- def swish(beta, x):
54
- return x * torch.sigmoid(beta * x)
55
-
1
+ import torch
2
+
3
+
4
+ def npu_swiglu(x, dim=-1):
5
+ tensor_dtype = x.dtype
6
+
7
+ inTensors = torch.chunk(x, 2, dim=dim)
8
+ if tensor_dtype == torch.float32:
9
+ tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
10
+ output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
11
+ else:
12
+ tensor_self_float = inTensors[0].type(torch.float)
13
+ tensor_other_float = inTensors[1].type(torch.float)
14
+ tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
15
+ torch.float32) * tensor_other_float
16
+ output_data = tensor_out_float.type(tensor_dtype)
17
+ return output_data
18
+
19
+
20
+ def npu_swiglu_backward(grad, x, dim=-1):
21
+ tensor_dtype = grad.dtype
22
+ in_tensors = torch.chunk(x, 2, dim=dim)
23
+ tensor_grad_out = grad
24
+
25
+ if tensor_dtype == torch.float16:
26
+ tensor_out1 = torch.mul(
27
+ torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
28
+ tensor_grad_out.type(torch.float32)).type(torch.float16)
29
+ tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
30
+ swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)
31
+ output = torch.cat((tensor_out1, tensor_out2), dim)
32
+ elif tensor_dtype == torch.bfloat16:
33
+ tensor_self_float = in_tensors[0].type(torch.float)
34
+ tensor_other_float = in_tensors[1].type(torch.float)
35
+ tensor_gradout_float = tensor_grad_out.type(torch.float)
36
+
37
+ tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
38
+ torch.float32) * tensor_other_float
39
+ tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float
40
+ tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
41
+ output = tensor_out_float.type(torch.bfloat16)
42
+ else:
43
+ tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)
44
+ tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))
45
+ output = torch.cat((tensor_out1, tensor_out2), dim)
46
+ return output.cpu()
47
+
48
+
49
+ def swish_grad(beta, x):
50
+ return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
51
+
52
+
53
+ def swish(beta, x):
54
+ return x * torch.sigmoid(beta * x)
55
+
@@ -1,2 +1,2 @@
1
- from .parse_json import parse_json_info_forward_backward
2
- from .utils import seed_all
1
+ from .parse_json import parse_json_info_forward_backward
2
+ from .utils import seed_all
@@ -1,14 +1,14 @@
1
- from ptdbg_ascend import compare
2
-
3
- pkl_path = "%s"
4
- dump_data_dir = "%s"
5
-
6
- dump_path_param = {
7
- "npu_pkl_path": ,
8
- "bench_pkl_path": ,
9
- "npu_dump_data_dir": ,
10
- "bench_dump_data_dir": ,
11
- "is_print_compare_log": True
12
- }
13
-
14
- compare(dump_path_param, output_path="", stack_mode=%s)
1
+ from ptdbg_ascend import compare
2
+
3
+ pkl_path = "%s"
4
+ dump_data_dir = "%s"
5
+
6
+ dump_path_param = {
7
+ "npu_pkl_path": ,
8
+ "bench_pkl_path": ,
9
+ "npu_dump_data_dir": ,
10
+ "bench_dump_data_dir": ,
11
+ "is_print_compare_log": True
12
+ }
13
+
14
+ compare(dump_path_param, output_path="", stack_mode=%s)
@@ -1,32 +1,21 @@
1
- import os
2
- import time
3
- import sys
4
- from msprobe.pytorch.common.utils import get_rank_if_initialized
5
- from msprobe.core.common.log import BaseLogger
6
- from msprobe.core.common.exceptions import DistributedNotInitializedError
7
-
8
-
9
- class PyTorchLogger(BaseLogger):
10
- def __init__(self):
11
- super().__init__()
12
-
13
- def get_rank(self):
14
- try:
15
- current_rank = get_rank_if_initialized()
16
- except DistributedNotInitializedError:
17
- current_rank = None
18
- return current_rank
19
-
20
- def _print_log(self, level, msg, end='\n'):
21
- current_rank = self.get_rank()
22
- current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
23
- pid = os.getpid()
24
- if current_rank is not None:
25
- full_msg = f"{current_time} ({pid}) [rank {current_rank}] [{level}] {msg}"
26
- else:
27
- full_msg = f"{current_time} ({pid}) [{level}] {msg}"
28
- print(full_msg, end=end)
29
- sys.stdout.flush()
30
-
31
-
1
+ import os
2
+ import time
3
+ import sys
4
+ from msprobe.pytorch.common.utils import get_rank_if_initialized
5
+ from msprobe.core.common.log import BaseLogger
6
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
7
+
8
+
9
+ class PyTorchLogger(BaseLogger):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def get_rank(self):
14
+ try:
15
+ current_rank = get_rank_if_initialized()
16
+ except DistributedNotInitializedError:
17
+ current_rank = None
18
+ return current_rank
19
+
20
+
32
21
  logger = PyTorchLogger()
@@ -1,39 +1,39 @@
1
- import json
2
-
3
- from msprobe.core.common.exceptions import ParseJsonException
4
- from msprobe.core.common.file_check import FileOpen
5
-
6
-
7
- def parse_json_info_forward_backward(json_path):
8
- def parse_data_name_with_pattern(data_name, pattern):
9
- name_struct = data_name.split('.')
10
- if not name_struct[-1] == pattern:
11
- raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
12
- f"{data_name} in file {json_path}")
13
- api_name = '.'.join(name_struct[:-1])
14
- return api_name
15
-
16
- with FileOpen(json_path, 'r') as f:
17
- dump_json = json.load(f)
18
-
19
- real_data_path = dump_json.get("dump_data_dir")
20
- dump_data = dump_json.get("data")
21
- if not dump_data:
22
- raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段")
23
-
24
- forward_data = {}
25
- backward_data = {}
26
- for data_name, data_item in dump_data.items():
27
- if "Module" in data_name:
28
- continue
29
- if "forward" in data_name:
30
- api_name = parse_data_name_with_pattern(data_name, "forward")
31
- forward_data.update({api_name: data_item})
32
- elif "backward" in data_name:
33
- api_name = parse_data_name_with_pattern(data_name, "backward")
34
- backward_data.update({api_name: data_item})
35
- else:
36
- raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
37
- f"{data_name} in file {json_path}.")
38
-
39
- return forward_data, backward_data, real_data_path
1
+ import json
2
+
3
+ from msprobe.core.common.exceptions import ParseJsonException
4
+ from msprobe.core.common.file_utils import FileOpen
5
+
6
+
7
+ def parse_json_info_forward_backward(json_path):
8
+ def parse_data_name_with_pattern(data_name, pattern):
9
+ name_struct = data_name.split('.')
10
+ if not name_struct[-1] == pattern:
11
+ raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
12
+ f"{data_name} in file {json_path}")
13
+ api_name = '.'.join(name_struct[:-1])
14
+ return api_name
15
+
16
+ with FileOpen(json_path, 'r') as f:
17
+ dump_json = json.load(f)
18
+
19
+ real_data_path = dump_json.get("dump_data_dir")
20
+ dump_data = dump_json.get("data")
21
+ if not dump_data:
22
+ raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段")
23
+
24
+ forward_data = {}
25
+ backward_data = {}
26
+ for data_name, data_item in dump_data.items():
27
+ if "Module" in data_name:
28
+ continue
29
+ if "forward" in data_name:
30
+ api_name = parse_data_name_with_pattern(data_name, "forward")
31
+ forward_data.update({api_name: data_item})
32
+ elif "backward" in data_name:
33
+ api_name = parse_data_name_with_pattern(data_name, "backward")
34
+ backward_data.update({api_name: data_item})
35
+ else:
36
+ raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
37
+ f"{data_name} in file {json_path}.")
38
+
39
+ return forward_data, backward_data, real_data_path