mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}精度工具数据dump标准性能基线报告.md +0 -0
@@ -0,0 +1,151 @@
1
+ # 无标杆工具场景验证和性能基线报告
2
+
3
+ ## 环境信息
4
+
5
+ NPU:Atlas A2 训练系列产品
6
+
7
+ CPU:
8
+
9
+ ![输入图片说明](img/cpu_info.png)
10
+
11
+ Torch:2.1.0
12
+
13
+ CANN:8.0.T5
14
+
15
+ 除上述环境信息影响性能外,API的数量、种类以及Shape都会对性能产生影响,因此本次选取不同场景网络和不同算子进行测试。
16
+
17
+ ## 模型信息和性能基线
18
+
19
+ 大模型在使用msprobe工具dump数据时,建议先简化模型层数,减少dump数据量。
20
+
21
+ 以下场景的性能基线测试数据均为多次测试后取平均值,因此实际运行时性能数据可能会根据环境状态稍有浮动。
22
+
23
+
24
+
25
+ ### LLaMA2-7B
26
+
27
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
28
+ 其中,softmax算子为FLOAT32,输入输出均为2G大小,为模型最大显存开销的API。
29
+
30
+ 在该模型下,对无标杆工具的处理模式、插桩范围、扰动方式的各种组合下的性能和显存基线进行覆盖。
31
+
32
+ 性能基线报告
33
+ 其中耗时为训练10步,去除第一步耗时所得的平均每步耗时。
34
+
35
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
36
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
37
+ | / | / | / | / | 0.24 | 13.69 | 1 | 1 | 混精模式基线 |
38
+ | check | 前 | ["softmax"] | improve_precision | 0.26 | 13.69 | 1.08 | 1 | softmax本身为高精度,跳过 |
39
+ | check | 前 | ["softmax"] | add_noise | 0.54 | 19.17 | 2.25 | 1.40 | |
40
+ | check | 前 | ["softmax"] | bit_noise | 0.56 | 19.17 | 2.33 | 1.40 | |
41
+ | check | 前 | ["softmax"] | change_value | 0.48 | 14.9 | 2 | 1.09 | |
42
+ | check | 前 | ["softmax"] | no_change | 0.47 | 14.9 | 1.96 | 1.09 | |
43
+ | check | 前 | ["softmax"] | to_cpu | 26.45 | 22.67 | 110.21 | 1.66 | 不建议整网 |
44
+ | check | 前 | ["matmul"] | improve_precision | 0.57 | 13.69 | 2.38 | 1 | |
45
+ | check | 前 | ["matmul"] | change_value | 0.48 | 13.69 | 2 | 1 | |
46
+ | check | 前 | ["matmul"] | to_cpu | 78.43 | 19.20 | 326.79 | 1.40 | 不建议整网 |
47
+ | check | 前 | [] | improve_precision | 3.45 | 18.79 | 14.37 | 1.37 | |
48
+ | check | 前 | [] | add_noise | 4.67 | 19.17 | 19.46 | 1.40 | |
49
+ | check | 前 | [] | bit_noise | 16.99 | 19.17 | 70.79 | 1.40 | |
50
+ | check | 前 | [] | no_change | 3.22 | 14.90 | 13.42 | 1.09 | |
51
+ | check | 反 | ["softmax"] | improve_precision | 6.23 | 25.69 | 25.96 | 1.88 | 不建议整网 |
52
+ | check | 反 | ["softmax"] | change_value | 22.76 | 25.69 | 94.83 | 1.88 | 不建议整网 |
53
+ | check | 反 | ["softmax"] | to_cpu | 141.71 | 26.19 | 590.46 | 1.91 | 不建议整网 |
54
+ | fix | 前 | ["softmax"] | to_cpu | 9.70 | 16.67 | 40.42 | 1.22 | 不支持整网、不支持反向 |
55
+ | fix | 前 | ["softmax"] | improve_precision | 0.26 | 14.67 | 1.08 | 1.07 | 不支持整网、不支持反向 |
56
+ | 预热 | 前 | [] | improve_precision | 155.07 | 24.79 | 646.13 | 1.81 | 低精度模型基线、只测预热的迭代 |
57
+ | 预热 | 反 | [] | improve_precision | 72.29 | 22.01 | 301.21 | 1.61 | 低精度模型基线、只测预热的迭代,grad_output为高精度的算子跳过 |
58
+
59
+ ### Aquila2-7B
60
+
61
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
62
+
63
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
64
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
65
+ | / | / | / | / | 0.17 | 13.66 | 1 | 1 | 混精模式基线 |
66
+ | check | 前 | [] | improve_precision | 1.57 | 14.24 | 9.24 | 1.04 | |
67
+ | check | 反 | [] | add_noise | 21.05 | 14.19 | 123.82 | 1.04 | |
68
+ | fix | 前 | [] | improve_precision | 0.95 | 15.55 | 5.59 | 1.14 | |
69
+
70
+ ### Baichuan2-7B
71
+
72
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
73
+
74
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
75
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
76
+ | / | / | / | / | 0.26 | 12.12 | 1 | 1 | 混精模式基线 |
77
+ | check | 前 | [] | improve_precision | 1.02 | 12.27 | 3.92 | 1.01 | |
78
+ | check | 反 | [] | add_noise | 11.15 | 12.67 | 42.88 | 1.05 | |
79
+ | fix | 前 | [] | improve_precision | 0.95 | 12.82 | 3.65 | 1.06 | |
80
+
81
+ ### Bloom-7B
82
+
83
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
84
+
85
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
86
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
87
+ | / | / | / | / | 0.14 | 9.51 | 1 | 1 | 混精模式基线 |
88
+ | check | 前 | [] | improve_precision | 1.64 | 11.58 | 11.71 | 1.22 | |
89
+ | check | 反 | [] | add_noise | 17.15 | 9.51 | 122.5 | 1 | |
90
+ | fix | 前 | [] | improve_precision | 0.87 | 10.62 | 6.21 | 1.12 | |
91
+
92
+ ### InternLM-7B
93
+
94
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
95
+
96
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
97
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
98
+ | / | / | / | / | 0.13 | 10.76 | 1 | 1 | 混精模式基线 |
99
+ | check | 前 | [] | improve_precision | 1.19 | 11.68 | 9.15 | 1.09 | |
100
+ | check | 反 | [] | add_noise | 11.69 | 10.89 | 89.92 | 1.01 | |
101
+ | fix | 前 | [] | improve_precision | 0.75 | 11.68 | 5.77 | 1.09 | |
102
+
103
+ ### Qwen-7B
104
+
105
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
106
+
107
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
108
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
109
+ | / | / | / | / | 0.28 | 18.41 | 1 | 1 | 混精模式基线 |
110
+ | check | 前 | [] | improve_precision | 2.34 | 23.18 | 8.36 | 1.26 | |
111
+ | check | 反 | [] | add_noise | 22.07 | 19.47 | 78.82 | 1.06 | |
112
+ | fix | 前 | [] | improve_precision | 1.31 | 21.11 | 4.68 | 1.15 | |
113
+
114
+ ### Gemma-7B
115
+
116
+ NUM_LAYER:1,1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelLink
117
+
118
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
119
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
120
+ | / | / | / | / | 0.15 | 11.06 | 1 | 1 | 混精模式基线 |
121
+ | check | 前 | [] | improve_precision | 1.49 | 13.17 | 9.93 | 1.19 | |
122
+ | check | 反 | [] | add_noise | 16.69 | 11.06 | 111.27 | 1 | |
123
+ | fix | 前 | [] | improve_precision | 0.87 | 12.25 | 5.8 | 1.11 | |
124
+
125
+ ### ResNet50-Cifar
126
+ 1卡,主要数据类型:FLOAT16,模型来源: ascend/ModelZoo-PyTorch。
127
+ 主要算子为conv2d,每个step有51个, 因此对conv2d进行检测。
128
+ CV模型、依赖mmcv实现(如果不修改mmcv代码、工具无法获取step信息和反向信息)。
129
+
130
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
131
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
132
+ | / | / | / | / | 0.09 | 7.63 | 1 | 1 | 基线 |
133
+ | check | 前 | ["conv2d"] | improve_precision | 0.889 | 7.94 | 9.81 | 1.04 | |
134
+ | fix | 前 | ["conv2d"] | improve_precision | 0.328 | 7.47 | 3.64 | 0.91 | |
135
+ | fix | 前 | ["conv2d"] | to_cpu | 12.23 | 7.47 | 135.88 | 0.91 | |
136
+
137
+ ### OpenSora1.0
138
+
139
+ 4卡,主要数据类型:FLOAT16,模型来源: ascend/ModelZoo-PyTorch
140
+
141
+ 每张卡每个step中linear算子个数为257个,FA算子个数为83(FA算子反向无效)。
142
+
143
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
144
+ |--------------------------------|-----------------------------------|-----------------|----------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|--------------------------------|
145
+ | / | / | / | / | 0.99 | 17.61 | 1 | 1 | 混精模式基线 |
146
+ | check | 前 | ["linear","npu_fusion_attention"] | improve_precision | 3.88 | 17.61 | 3.92 | 1 | |
147
+ | check | 前 | ["linear","npu_fusion_attention"] | add_noise | 3.46 | 17.61 | 3.49 | 1 | |
148
+ | check | 反 | ["linear"] | improve_precision | 12.61 | 17.61 | 12.74 | 1 | |
149
+ | check | 反 | ["linear"] | add_noise | 9.8 | 17.61 | 9.90 | 1 | |
150
+ | fix | 前 | ["linear"] | to_cpu | 18.83 | 17.61 | 19.02 | 1 | |
151
+ | fix | 前 | ["linear"] | improve_precision | 2.83 | 17.61 | 2.86 | 1 | |
@@ -52,6 +52,7 @@ class ThresholdConfig:
52
52
 
53
53
  DTYPE_PER_THD = {
54
54
  torch.float16: 1.002,
55
+ torch.bfloat16: 1.004,
55
56
  torch.float32: 1.0002,
56
57
  }
57
58
  BENCHMARK_THD_DICT = {
@@ -60,6 +61,8 @@ class ThresholdConfig:
60
61
  torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
61
62
  }
62
63
 
64
+ TENSOR_SPLIT_MAX_CHUNK = 128
65
+
63
66
 
64
67
  class PreheatConfig:
65
68
  IF_PREHEAT = "if_preheat"
@@ -96,3 +96,7 @@ class TorchC:
96
96
  add = torch._C._VariableFunctionsClass.add
97
97
  bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
98
98
  clone = torch._C._VariableFunctionsClass.clone
99
+ clamp = torch._C._VariableFunctionsClass.clamp
100
+ tensor_split = torch._C._VariableFunctionsClass.tensor_split
101
+ stack = torch._C._VariableFunctionsClass.stack
102
+ reshape = torch._C._VariableFunctionsClass.reshape
@@ -2,7 +2,7 @@ import torch
2
2
  from msprobe.core.common.exceptions import FreeBenchmarkException
3
3
  from msprobe.pytorch.free_benchmark import logger
4
4
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal
6
6
  from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
7
7
  from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
8
8
  FuzzHandlerFactory,
@@ -16,7 +16,6 @@ class GradSaver:
16
16
  self.handler_params = handler_params
17
17
  self.api_name = handler_params.api_name
18
18
  self.origin_func = origin_func
19
- self.data_params = DataParams()
20
19
  self.is_compare = True
21
20
  self.kwargs = dict()
22
21
  self.perturbed_grad_input = tuple()
@@ -61,28 +60,25 @@ class GradSaver:
61
60
  _index += 1
62
61
 
63
62
  def compare_grad_results(self, handler, origin_grad, perturbed_grad, index):
64
- # TODO get dtype?
65
- self.data_params.original_result = origin_grad
66
- self.data_params.perturbed_result = perturbed_grad
67
- self.data_params.grad_unequal_flag = False
68
- self.data_params.valid_input_index = index
63
+ data_params = DataParams()
64
+ data_params.original_result = origin_grad
65
+ data_params.perturbed_result = perturbed_grad
66
+ data_params.grad_unequal_flag = False
67
+ data_params.valid_input_index = index
69
68
  try:
70
- handler.handle(self.data_params)
71
- if not self.data_params.is_consistent:
69
+ handler.handle(data_params)
70
+ if not data_params.is_consistent:
72
71
  self.is_compare = False
73
- self.data_params.grad_unequal_flag = True
74
- self.data_params.is_consistent = True
75
- self.data_params.perturbed_result = self.perturbed_grad_input
76
- self.data_params.original_result = self.origin_grad_input
77
- handler.handle(self.data_params)
72
+ data_params.grad_unequal_flag = True
73
+ data_params.is_consistent = True
74
+ data_params.perturbed_result = self.perturbed_grad_input
75
+ data_params.original_result = self.origin_grad_input
76
+ handler.handle(data_params)
78
77
  except Exception as e:
79
78
  logger.warning_on_rank_0(
80
79
  f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}."
81
80
  f"{e}"
82
81
  )
83
- # 在扰动前后输出对比后释放输出的引用
84
- self.data_params.perturbed_result = None
85
- self.data_params.original_result = None
86
82
 
87
83
  def check_grad_input(self, origin_grad, new_grad_index):
88
84
  if self.perturbed_grad_input is None:
@@ -164,20 +160,20 @@ class GradSaver:
164
160
  return grad_input
165
161
 
166
162
  def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args):
167
- self.data_params.args = [need_grad_tensors, grad_output, inner_args]
168
- self.data_params.kwargs = {}
169
- self.data_params.valid_input_index = 0
170
- self.data_params.origin_func = self.get_grad_input_from_vjp
163
+ data_params = data_pre_deal(
164
+ self.handler_params.api_name,
165
+ self.get_grad_input_from_vjp,
166
+ [need_grad_tensors, grad_output, inner_args],
167
+ {}
168
+ )
171
169
  layer = LayerFactory.create(
172
170
  self.handler_params.api_name,
173
171
  self.handler_params.fuzz_device,
174
172
  self.handler_params.pert_mode,
175
173
  )
176
- layer.handle(self.data_params)
177
- # 在计算扰动输出之后,释放输入的引用
178
- self.data_params.args = None
174
+ layer.handle(data_params)
179
175
  # 确定扰动成功后,才会暂存
180
- if self.data_params.perturbed_result:
176
+ if data_params.perturbed_result:
181
177
  self.perturbed_grad_input = tuple(
182
- [x.cpu() for x in self.data_params.perturbed_result]
178
+ [x.cpu() for x in data_params.perturbed_result]
183
179
  )
@@ -10,7 +10,10 @@ from msprobe.pytorch.free_benchmark.common.enums import (
10
10
  HandlerType,
11
11
  PerturbationMode,
12
12
  )
13
- from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params
13
+ from msprobe.pytorch.free_benchmark.common.params import (
14
+ data_pre_deal,
15
+ make_handler_params,
16
+ )
14
17
  from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver
15
18
  from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
16
19
  from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
@@ -70,9 +73,9 @@ class FreeBenchmarkCheck(ABC):
70
73
  layer.handle(data_params)
71
74
  handler_params = make_handler_params(name, self.config, self.current_iter)
72
75
  handler = FuzzHandlerFactory.create(handler_params)
73
- handler.handle(data_params)
74
- return data_params.perturbed_result, handler.get_unequal_rows()
75
-
76
+ perturbed_output = handler.handle(data_params)
77
+ return perturbed_output, handler.get_unequal_rows()
78
+
76
79
  def backward(self, name, module, grad_output):
77
80
 
78
81
  if not self.config.fuzz_stage == Const.BACKWARD:
@@ -32,7 +32,7 @@ class AddNoiseLayer(NpuBaseLayer):
32
32
  return type(tensor_obj)([self.add_noise(value) for value in tensor_obj])
33
33
  return tensor_obj
34
34
 
35
- def handle(self, params: DataParams) -> torch.Any:
35
+ def handle(self, params: DataParams):
36
36
  """
37
37
  对输入添加扰动并返回
38
38
  """
@@ -48,7 +48,7 @@ class BitNoiseLayer(NpuBaseLayer):
48
48
  return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
49
49
  return tensor_obj
50
50
 
51
- def handle(self, params: DataParams) -> torch.Any:
51
+ def handle(self, params: DataParams):
52
52
  """
53
53
  对输入添加扰动并返回
54
54
  """
@@ -39,7 +39,7 @@ class ChangeValueLayer(NpuBaseLayer):
39
39
  return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
40
40
  return tensor_obj
41
41
 
42
- def handle(self, params: DataParams) -> torch.Any:
42
+ def handle(self, params: DataParams):
43
43
  """
44
44
  对输入添加扰动并返回
45
45
  """
@@ -17,7 +17,7 @@ class ImprovePrecisionLayer(NpuBaseLayer):
17
17
  and torch.is_floating_point(tensor_obj)
18
18
  and tensor_obj.dtype not in [torch.float32, torch.float64]
19
19
  ):
20
- self._set_improve_valus(tensor_obj)
20
+ self._set_improve_values(tensor_obj)
21
21
  tensor_obj = self._change_dtype(tensor_obj)
22
22
  self.is_added = True
23
23
  return tensor_obj
@@ -32,7 +32,7 @@ class ImprovePrecisionLayer(NpuBaseLayer):
32
32
  )
33
33
  return tensor_obj
34
34
 
35
- def handle(self, params: DataParams) -> torch.Any:
35
+ def handle(self, params: DataParams):
36
36
  logger.info_on_rank_0(
37
37
  f"[msprobe] Free benchmark: Perturbation is "
38
38
  f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
@@ -50,7 +50,7 @@ class ImprovePrecisionLayer(NpuBaseLayer):
50
50
  params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
51
51
  return params.perturbed_result
52
52
 
53
- def _set_improve_valus(self, inputs):
53
+ def _set_improve_values(self, inputs):
54
54
  if inputs.dtype in [torch.float16, torch.bfloat16]:
55
55
  self.perturbed_value = torch.float32
56
56
 
@@ -16,7 +16,7 @@ class NoChangeLayer(NpuBaseLayer):
16
16
  self.is_added = True
17
17
  return tensor_obj
18
18
 
19
- def handle(self, params: DataParams) -> torch.Any:
19
+ def handle(self, params: DataParams):
20
20
  """
21
21
  对输入添加扰动并返回
22
22
  """
@@ -8,7 +8,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
8
8
 
9
9
  class CpuLayer(BaseLayer):
10
10
 
11
- def handle(self, params: DataParams) -> torch.Any:
11
+ def handle(self, params: DataParams):
12
12
 
13
13
  logger.info_on_rank_0(
14
14
  f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
@@ -1,6 +1,7 @@
1
1
  import math
2
2
  from abc import ABC, abstractmethod
3
3
  from typing import Any, Optional, Tuple
4
+ import numpy as np
4
5
 
5
6
  import torch
6
7
  from msprobe.core.common.const import Const
@@ -34,15 +35,36 @@ class FuzzHandler(ABC):
34
35
  origin_ouput = origin_ouput.values
35
36
  perturbed_output = perturbed_output.values
36
37
  if hasattr(perturbed_output, "dtype"):
37
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype)
38
+ abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
38
39
  else:
39
- abs_tol = FuzzThreshold.F32_THD.value
40
+ abs_tol = FuzzThreshold.F32_THD
40
41
  return (
41
42
  origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
42
43
  perturbed_output,
43
44
  abs_tol,
44
45
  )
45
46
 
47
@staticmethod
def tensor_split_for_error_calculate(origin_output, perturbed_output):
    """
    Split the outputs before and after perturbation into chunks for error calculation.

    Both tensors are flattened to 1-D and split into the same number of chunks.
    The chunk count is derived from the size of ``origin_output`` only.

    :param origin_output: original output
    :param perturbed_output: output after perturbation
    :return origin_output_chunks: chunks of the flattened original output
    :return perturbed_output_chunks: chunks of the flattened perturbed output
    """
    # Size of the original output expressed in units of Const.ONE_MB.
    size_in_mb = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
    # Empty tensors and 0-dim (scalar) tensors are returned whole, each wrapped in a list.
    if origin_output.ndim == 0 or size_in_mb == 0:
        return [origin_output], [perturbed_output]

    # Relation between tensor size and chunk count:
    # chunks_exp = log2(M) - 4, chunks = 2 ** chunks_exp (M is the tensor size in MB),
    # then bounded to the range [1, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK].
    exponent = int(math.log(size_in_mb, 2)) - 4
    chunk_count = min(max(2 ** exponent, 1), ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)

    flat_origin = TorchC.reshape(origin_output, (-1,))
    flat_perturbed = TorchC.reshape(perturbed_output, (-1,))
    return (
        TorchC.tensor_split(flat_origin, chunk_count),
        TorchC.tensor_split(flat_perturbed, chunk_count),
    )
67
+
46
68
  @staticmethod
47
69
  def convert_overflow_ratio_to_consistent(ratio):
48
70
  if math.isnan(ratio) or math.isinf(ratio):
@@ -61,36 +83,28 @@ class FuzzHandler(ABC):
61
83
  self, origin_output, perturbed_output, norm_type, abs_tol
62
84
  ):
63
85
  if norm_type == NormType.ENDLESS_NORM:
64
- return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
86
+ return self.calculate_error(origin_output, perturbed_output, abs_tol)
65
87
  return ThresholdConfig.COMP_CONSISTENT
66
88
 
67
- def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
68
- ratio_tensor1 = TorchC.where(
69
- TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
70
- TorchC.div(
71
- TorchC.abs(origin_output),
72
- TorchC.add(TorchC.abs(perturbed_output), abs_tol),
73
- ),
74
- 1,
75
- )
76
- ratio_tensor2 = TorchC.where(
77
- TorchC.gt(TorchC.abs(origin_output), abs_tol),
78
- TorchC.div(
79
- TorchC.abs(perturbed_output),
80
- TorchC.add(TorchC.abs(origin_output), abs_tol),
81
- ),
82
- 1,
83
- )
89
+ def calculate_error(self, origin_output, perturbed_output, abs_tol):
90
+ origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
91
+ norm1 = -np.inf
92
+ norm2 = -np.inf
93
+ norm3 = np.inf
94
+ for i, chunk_origin in enumerate(origin_output_chunks):
95
+ if chunk_origin.nelement() == 0:
96
+ break
97
+ chunk_perturbed = perturbed_output_chunks[i]
98
+ ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
99
+ TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
100
+ ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
101
+ TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
102
+ norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
103
+ max_ratio1, max_ratio2 = norm_values.tolist()
104
+ norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
105
+ norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
106
+ norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
84
107
 
85
- norm1 = self.convert_overflow_ratio_to_consistent(
86
- TorchC.max(ratio_tensor1).item()
87
- )
88
- norm2 = self.convert_overflow_ratio_to_consistent(
89
- TorchC.max(ratio_tensor2).item()
90
- )
91
- norm3 = self.convert_overflow_ratio_to_consistent(
92
- TorchC.min(ratio_tensor1).item()
93
- )
94
108
  if norm3 < 0:
95
109
  ratio = ThresholdConfig.SYMBOL_FLIPPING
96
110
  else:
@@ -22,7 +22,6 @@ class FuzzHandlerFactory:
22
22
  handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
23
23
  else:
24
24
  handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
25
- # TODO
26
25
  if not handler:
27
26
  raise FreeBenchmarkException(
28
27
  FreeBenchmarkException.UnsupportedType,
@@ -0,0 +1,75 @@
1
+ from msprobe.pytorch.common.utils import logger
2
+ from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
3
+ from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
4
+ npu_confusion_transpose_backward
5
+ from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward
6
+ from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
7
+ from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
8
+ from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
9
+ from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad
10
+ from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
11
+ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
12
+ from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
13
+ npu_scaled_masked_softmax_backward
14
+ from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
15
+
16
+
17
class Register(dict):
    """
    Registry of callables keyed by each callable's ``__name__``.

    Entries are stored in the inherited ``dict`` itself, so every mapping
    operation (``[]``, ``in``, ``get``, ``len``, iteration, ``keys``,
    ``values``, ``items``, ``==``, ``str``) observes the same data. Keeping
    the entries in a separate private dict would leave the inherited
    operations that are not explicitly overridden (``get``, ``len``,
    iteration, ...) looking at an empty mapping.

    Calling the instance with an iterable of callables registers each of them.
    """

    @property
    def _dict(self):
        # Backward-compatible alias for code that reads the former private
        # storage attribute; the registry itself is now the storage.
        return self

    def __call__(self, target_func_list):
        """
        Register every callable in ``target_func_list``.

        :param target_func_list: iterable of callables to register
        :return: None
        :raises TypeError: if an element is not callable
        """
        for target in target_func_list:
            self.register(target)

    def register(self, target):
        """
        Register ``target`` under ``target.__name__`` and return it.

        An existing entry with the same name is replaced after a warning is logged.

        :param target: callable to register
        :return: ``target`` itself
        :raises TypeError: if ``target`` is not callable
        """
        if not callable(target):
            # TypeError is a subclass of Exception, so callers that catch
            # Exception keep working.
            raise TypeError(f"The func {target} is not callable.")
        key = target.__name__
        if key in self:
            logger.warning(f"{target.__name__} has been registered before, so we will overriden it.")
        self[key] = target
        return target
60
+
61
+
62
# Registry of the NPU custom bench functions; each one is stored under its
# own function name.
npu_custom_functions = Register()
npu_custom_functions((
    npu_apply_adam_w,
    npu_confusion_transpose,
    fast_gelu,
    npu_layer_norm_eval,
    npu_linear,
    npu_fusion_attention,
    npu_rms_norm,
    npu_rotary_mul,
    npu_scaled_masked_softmax,
    npu_swiglu,
))

# Registry of the NPU custom backward bench functions, also stored by
# function name.
npu_custom_grad_functions = Register()
npu_custom_grad_functions((
    npu_confusion_transpose_backward,
    npu_fast_gelu_backward,
    npu_linear_backward,
    matmul_backward,
    npu_fusion_attention_grad,
    npu_rms_norm_backward,
    npu_rotary_mul_backward,
    npu_scaled_masked_softmax_backward,
    npu_swiglu_backward,
))
@@ -3,7 +3,7 @@ from msprobe.pytorch.common.log import logger
3
3
  from msprobe.core.common.const import Const
4
4
  from msprobe.pytorch.hook_module.api_registry import api_register
5
5
  from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
6
- from msprobe.core.common.exceptions import MsaccException
6
+ from msprobe.core.common.exceptions import MsprobeException
7
7
  from msprobe.core.data_dump.scope import BaseScope
8
8
 
9
9
  module_count = {}
@@ -12,10 +12,10 @@ module_count = {}
12
12
  def module_dump(module, dump_name):
13
13
  if not isinstance(module, nn.Module):
14
14
  logger.error("The parameter:module in module_dump is not a Module subclass.")
15
- raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
15
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
16
16
  if not isinstance(dump_name, str):
17
17
  logger.error("The parameter:dump_name in module_dump is not a str type.")
18
- raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
18
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
19
19
  api_register.api_originality()
20
20
  if dump_name not in module_count:
21
21
  module_count[dump_name] = 0
@@ -24,7 +24,7 @@ def module_dump(module, dump_name):
24
24
  dump_name = dump_name + Const.SEP + str(module_count.get(dump_name)) + Const.SEP
25
25
 
26
26
  pdg = PrecisionDebugger()
27
- _, forward_hook, backward_hook = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name)
27
+ _, forward_hook, backward_hook, _ = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name)
28
28
  module.register_forward_hook(forward_hook, with_kwargs=True)
29
29
  module.register_full_backward_hook(backward_hook)
30
30
 
File without changes