mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,39 +1,39 @@
1
- from typing import Any
2
-
3
- from msprobe.pytorch.free_benchmark import logger
4
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
6
- from msprobe.pytorch.free_benchmark.common.utils import Tools
7
- from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
8
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
9
-
10
-
11
- class CheckerHandler(FuzzHandler):
12
- def other_compare(self, data_params: DataParams) -> bool:
13
- is_consistent = SingleCompare().compare_seq(
14
- data_params.original_result, data_params.perturbed_result
15
- )
16
- if not is_consistent:
17
- self.unequal_rows.append(
18
- make_unequal_row(data_params, self.params)
19
- )
20
-
21
- def get_threshold(self, dtype):
22
- return self._get_default_threshold(dtype)
23
-
24
- def handle(self, data_params: DataParams) -> Any:
25
- if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
26
- data_params.perturbed_result
27
- ):
28
- return data_params.original_result
29
- try:
30
- if self.params.fuzz_device == DeviceType.NPU:
31
- self.cmp_output_npu(data_params)
32
- else:
33
- self.other_compare(data_params)
34
- except Exception as e:
35
- logger.warning_on_rank_0(
36
- f"[msprobe] Free Benchmark: For {self.params.api_name}, "
37
- f"when campare the result exception raise {e}"
38
- )
39
- return data_params.original_result
1
+ from typing import Any
2
+
3
+ from msprobe.pytorch.free_benchmark import logger
4
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
6
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
7
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
8
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
9
+
10
+
11
+ class CheckerHandler(FuzzHandler):
12
+ def other_compare(self, data_params: DataParams) -> bool:
13
+ is_consistent = SingleCompare().compare_seq(
14
+ data_params.original_result, data_params.perturbed_result
15
+ )
16
+ if not is_consistent:
17
+ self.unequal_rows.append(
18
+ make_unequal_row(data_params, self.params)
19
+ )
20
+
21
+ def get_threshold(self, dtype):
22
+ return self._get_default_threshold(dtype)
23
+
24
+ def handle(self, data_params: DataParams) -> Any:
25
+ if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
26
+ data_params.perturbed_result
27
+ ):
28
+ return data_params.original_result
29
+ try:
30
+ if self.params.fuzz_device == DeviceType.NPU:
31
+ self.cmp_output_npu(data_params)
32
+ else:
33
+ self.other_compare(data_params)
34
+ except Exception as e:
35
+ logger.warning_on_rank_0(
36
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
37
+ f"when campare the result exception raise {e}"
38
+ )
39
+ return data_params.original_result
@@ -1,24 +1,24 @@
1
- from typing import Any
2
-
3
- from msprobe.pytorch.free_benchmark.common.params import DataParams
4
- from msprobe.pytorch.free_benchmark.common.utils import Tools
5
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
6
- from msprobe.pytorch.free_benchmark import logger
7
-
8
-
9
- class FixHandler(FuzzHandler):
10
-
11
- def get_threshold(self, dtype):
12
- return self._get_default_threshold(dtype)
13
-
14
- def handle(self, data_params: DataParams) -> Any:
15
- try:
16
- return Tools.convert_fuzz_output_to_origin(
17
- data_params.original_result, data_params.perturbed_result
18
- )
19
- except Exception as e:
20
- logger.warning_on_rank_0(
21
- f"[msprobe] Free Benchmark: For {self.params.api_name} "
22
- f"Fix output failed. "
23
- )
1
+ from typing import Any
2
+
3
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
4
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
5
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
6
+ from msprobe.pytorch.free_benchmark import logger
7
+
8
+
9
+ class FixHandler(FuzzHandler):
10
+
11
+ def get_threshold(self, dtype):
12
+ return self._get_default_threshold(dtype)
13
+
14
+ def handle(self, data_params: DataParams) -> Any:
15
+ try:
16
+ return Tools.convert_fuzz_output_to_origin(
17
+ data_params.original_result, data_params.perturbed_result
18
+ )
19
+ except Exception as e:
20
+ logger.warning_on_rank_0(
21
+ f"[msprobe] Free Benchmark: For {self.params.api_name} "
22
+ f"Fix output failed. "
23
+ )
24
24
  return data_params.original_result
@@ -1,31 +1,30 @@
1
- from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
- from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
3
- from msprobe.pytorch.free_benchmark.common.enums import HandlerType
4
- from msprobe.pytorch.free_benchmark.common.params import HandlerParams
5
- from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
6
- from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
7
- from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
8
-
9
-
10
- class FuzzHandlerFactory:
11
-
12
- result_handlers = {
13
- HandlerType.CHECK: CheckerHandler,
14
- HandlerType.FIX: FixHandler,
15
- HandlerType.PREHEAT: PreheatHandler,
16
- }
17
-
18
- @staticmethod
19
- def create(params: HandlerParams):
20
- if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
21
- if not if_preheat:
22
- handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
23
- else:
24
- handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
25
- # TODO
26
- if not handler:
27
- raise FreeBenchmarkException(
28
- FreeBenchmarkException.UnsupportedType,
29
- f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
30
- )
31
- return handler(params)
1
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
+ from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
3
+ from msprobe.pytorch.free_benchmark.common.enums import HandlerType
4
+ from msprobe.pytorch.free_benchmark.common.params import HandlerParams
5
+ from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
6
+ from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
7
+ from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
8
+
9
+
10
+ class FuzzHandlerFactory:
11
+
12
+ result_handlers = {
13
+ HandlerType.CHECK: CheckerHandler,
14
+ HandlerType.FIX: FixHandler,
15
+ HandlerType.PREHEAT: PreheatHandler,
16
+ }
17
+
18
+ @staticmethod
19
+ def create(params: HandlerParams):
20
+ if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
21
+ if not if_preheat:
22
+ handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
23
+ else:
24
+ handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
25
+ if not handler:
26
+ raise FreeBenchmarkException(
27
+ FreeBenchmarkException.UnsupportedType,
28
+ f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
29
+ )
30
+ return handler(params)
@@ -1,170 +1,170 @@
1
- import math
2
- from typing import Any
3
-
4
- from msprobe.pytorch.free_benchmark import logger
5
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
6
- from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
7
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
8
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
9
- from msprobe.pytorch.free_benchmark.common.utils import Tools
10
- from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
11
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
12
-
13
-
14
class PreheatHandler(FuzzHandler):
    """Result handler that calibrates per-API fuzz thresholds during preheat steps.

    At step 0 it only counts how often each API is used. On the following steps
    (while ``step < preheat_config["preheat_step"]``) it samples some calls,
    records the perturbed-output ratio together with whether the NPU output
    agrees with a CPU re-run, and adjusts the API/dtype threshold once enough
    samples have been collected. ``handle`` always returns the original result.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the current preheat threshold of this API for ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original function on CPU and compare it with the NPU output.

        Arguments are moved to CPU with ``change_dtype=True`` before the call.

        Returns:
            The result of ``SingleCompare().compare_seq`` on (NPU output, CPU output).
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and adjust thresholds once enough samples exist."""
        # Store this step's output ratio with its NPU-vs-CPU comparison result.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Process one fuzzed call; always returns ``data_params.original_result``.

        Side effects: updates ``data_params.is_consistent`` /
        ``data_params.grad_unequal_flag`` and the global ``preheat_counter``.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result

        if self.params.step == 0:
            # Step 0 only counts how many times this API is used per step.
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # This api/step is a preheat candidate.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            # Best effort: a failed CPU comparison keeps cpu_consistent=True.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # No tensor dtype means no threshold bucket to update. Previously
            # execution fell through and raised UnboundLocalError on first_dtype.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call of this API should be sampled."""
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Number of samples to collect for this API in each step.

        ceil(calls per step / preheat_step), capped at ``max_sample``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self) -> set:
        """Call sequence numbers (in 1..calls-per-step) to sample in this step.

        Samples are spaced ``preheat_step`` apart and offset by the current step,
        so successive steps sample different calls.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        sample_count_per_step = self._get_sample_count_per_step()
        preheat_step = self.params.preheat_config.get("preheat_step")
        need_sample_set = set()
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            # Map remainder 0 onto the last call so sequence numbers stay 1-based.
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """True once this API has been sampled at least the per-step sample count."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for ``dtype_str`` from recorded samples.

        Args:
            dtype_str: dtype key of the preheat record.
            compare_result: iterable of ``(ratio, is_consistent)`` samples.

        Returns:
            The new threshold; the old one is returned unchanged when there are
            no samples at all.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [
            ratio for ratio, is_consistent in compare_result if not is_consistent
        ]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        if not con_ratio and not incon_ratio:
            # Nothing recorded: min() below would raise ValueError on empty lists.
            return old_thd
        new_thd = old_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist. Stop preheating only
            # when they are cleanly separated; otherwise keep the threshold.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale (old_thd - 1) by API_THD_STEP, up if
            # a consistent ratio exceeds the threshold, down otherwise.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples: clamp to the smallest inconsistent ratio.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store thresholds for every recorded dtype of this API."""
        for dtype_str, compare_result in preheat_counter.preheat_record[
            self.pure_name
        ].items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )
1
+ import math
2
+ from typing import Any
3
+
4
+ from msprobe.pytorch.free_benchmark import logger
5
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
6
+ from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
7
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
8
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
9
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
10
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
11
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
12
+
13
+
14
class PreheatHandler(FuzzHandler):
    """Result handler that calibrates per-API fuzz thresholds during preheat steps.

    At step 0 it only counts how often each API is used. On the following steps
    (while ``step < preheat_config["preheat_step"]``) it samples some calls,
    records the perturbed-output ratio together with whether the NPU output
    agrees with a CPU re-run, and adjusts the API/dtype threshold once enough
    samples have been collected. ``handle`` always returns the original result.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the current preheat threshold of this API for ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original function on CPU and compare it with the NPU output.

        Arguments are moved to CPU with ``change_dtype=True`` before the call.

        Returns:
            The result of ``SingleCompare().compare_seq`` on (NPU output, CPU output).
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and adjust thresholds once enough samples exist."""
        # Store this step's output ratio with its NPU-vs-CPU comparison result.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Process one fuzzed call; always returns ``data_params.original_result``.

        Side effects: updates ``data_params.is_consistent`` /
        ``data_params.grad_unequal_flag`` and the global ``preheat_counter``.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result

        if self.params.step == 0:
            # Step 0 only counts how many times this API is used per step.
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # This api/step is a preheat candidate.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            # Best effort: a failed CPU comparison keeps cpu_consistent=True.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # No tensor dtype means no threshold bucket to update. Previously
            # execution fell through and raised UnboundLocalError on first_dtype.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call of this API should be sampled."""
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Number of samples to collect for this API in each step.

        ceil(calls per step / preheat_step), capped at ``max_sample``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self) -> set:
        """Call sequence numbers (in 1..calls-per-step) to sample in this step.

        Samples are spaced ``preheat_step`` apart and offset by the current step,
        so successive steps sample different calls.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        sample_count_per_step = self._get_sample_count_per_step()
        preheat_step = self.params.preheat_config.get("preheat_step")
        need_sample_set = set()
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            # Map remainder 0 onto the last call so sequence numbers stay 1-based.
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """True once this API has been sampled at least the per-step sample count."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for ``dtype_str`` from recorded samples.

        Args:
            dtype_str: dtype key of the preheat record.
            compare_result: iterable of ``(ratio, is_consistent)`` samples.

        Returns:
            The new threshold; the old one is returned unchanged when there are
            no samples at all.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [
            ratio for ratio, is_consistent in compare_result if not is_consistent
        ]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        if not con_ratio and not incon_ratio:
            # Nothing recorded: min() below would raise ValueError on empty lists.
            return old_thd
        new_thd = old_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist. Stop preheating only
            # when they are cleanly separated; otherwise keep the threshold.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale (old_thd - 1) by API_THD_STEP, up if
            # a consistent ratio exceeds the threshold, down otherwise.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples: clamp to the smallest inconsistent ratio.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store thresholds for every recorded dtype of this API."""
        for dtype_str, compare_result in preheat_counter.preheat_record[
            self.pure_name
        ].items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )
@@ -0,0 +1,76 @@
1
+ from msprobe.pytorch.common.utils import logger
2
+ from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
3
+ from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
4
+ npu_confusion_transpose_backward
5
+ from msprobe.pytorch.bench_functions.fast_gelu import npu_fast_gelu, npu_fast_gelu_backward
6
+ from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
7
+ from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
8
+ from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
9
+ from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad, \
10
+ gpu_fusion_attention
11
+ from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
12
+ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
13
+ from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
14
+ npu_scaled_masked_softmax_backward
15
+ from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
16
+
17
+
18
+ class Register(dict):
19
+ def __init__(self, *args, **kwargs):
20
+ super(Register, self).__init__(*args, **kwargs)
21
+ self._dict = {}
22
+
23
+ def __call__(self, target_func_list):
24
+ for target in target_func_list:
25
+ self.register(target)
26
+ return
27
+
28
+ def __setitem__(self, key, value):
29
+ self._dict[key] = value
30
+
31
+ def __getitem__(self, key):
32
+ return self._dict[key]
33
+
34
+ def __contains__(self, key):
35
+ return key in self._dict
36
+
37
+ def __str__(self):
38
+ return str(self._dict)
39
+
40
+ def keys(self):
41
+ return self._dict.keys()
42
+
43
+ def values(self):
44
+ return self._dict.values()
45
+
46
+ def items(self):
47
+ return self._dict.items()
48
+
49
+ def register(self, target):
50
+
51
+ def add_register_item(key, value):
52
+ if key in self._dict:
53
+ logger.warning(f"{value.__name__} has been registered before, so we will overriden it.")
54
+ self[key] = value
55
+ return value
56
+
57
+ if callable(target):
58
+ return add_register_item(target.__name__, target)
59
+ else:
60
+ raise Exception(f"The func {target} is not callable.")
61
+
62
+
63
# Forward bench implementations of NPU custom operators, keyed by function name.
npu_custom_functions = Register()
npu_custom_functions((
    npu_apply_adam_w,
    npu_confusion_transpose,
    npu_fast_gelu,
    npu_layer_norm_eval,
    npu_linear,
    npu_fusion_attention,
    npu_rms_norm,
    npu_rotary_mul,
    npu_scaled_masked_softmax,
    npu_swiglu,
    gpu_fusion_attention,
))

# Backward (gradient) bench implementations, keyed by function name.
npu_custom_grad_functions = Register()
npu_custom_grad_functions((
    npu_confusion_transpose_backward,
    npu_fast_gelu_backward,
    npu_linear_backward,
    matmul_backward,
    npu_fusion_attention_grad,
    npu_rms_norm_backward,
    npu_rotary_mul_backward,
    npu_scaled_masked_softmax_backward,
    npu_swiglu_backward,
))