mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,28 +1,28 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
- from msprobe.pytorch.free_benchmark.common.params import DataParams
5
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
6
- NpuBaseLayer,
7
- )
8
-
9
-
10
- class NoChangeLayer(NpuBaseLayer):
11
-
12
- def no_change(self, tensor_obj):
13
- """
14
- 不对输入做任何改变、直接二次执行
15
- """
16
- self.is_added = True
17
- return tensor_obj
18
-
19
- def handle(self, params: DataParams) -> torch.Any:
20
- """
21
- 对输入添加扰动并返回
22
- """
23
- logger.info_on_rank_0(
24
- f"[msprobe] Free benchmark: Perturbation is "
25
- f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
26
- )
27
- params.perturbed_value = self.no_change(params.args[params.valid_input_index])
28
- return self.perturbed_result(params)
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
5
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
6
+ NpuBaseLayer,
7
+ )
8
+
9
+
10
+ class NoChangeLayer(NpuBaseLayer):
11
+
12
+ def no_change(self, tensor_obj):
13
+ """
14
+ 不对输入做任何改变、直接二次执行
15
+ """
16
+ self.is_added = True
17
+ return tensor_obj
18
+
19
+ def handle(self, params: DataParams):
20
+ """
21
+ 对输入添加扰动并返回
22
+ """
23
+ logger.info_on_rank_0(
24
+ f"[msprobe] Free benchmark: Perturbation is "
25
+ f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
26
+ )
27
+ params.perturbed_value = self.no_change(params.args[params.valid_input_index])
28
+ return self.perturbed_result(params)
@@ -1,45 +1,45 @@
1
- from abc import abstractmethod
2
- from typing import Any
3
-
4
- import torch
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
-
8
-
9
- class NpuBaseLayer(BaseLayer):
10
- def __init__(self, api_name: str) -> None:
11
- super().__init__(api_name)
12
- self.perturbed_value = None # 扰动的元素
13
- self.is_added = False # 标记当前算子输入是否调整
14
-
15
- @staticmethod
16
- def perturbed_result(params: DataParams) -> Any:
17
- args_front = params.args[: params.valid_input_index]
18
- args_rear = params.args[params.valid_input_index + 1:]
19
- # 此处会将有inplace属性的算子换为非inplace
20
- if "inplace" in params.kwargs:
21
- params.kwargs["inplace"] = False
22
- params.perturbed_result = params.origin_func(
23
- *args_front, params.perturbed_value, *args_rear, **params.kwargs
24
- )
25
- return params.perturbed_result
26
-
27
- @abstractmethod
28
- def handle(self, params: DataParams) -> Any:
29
- pass
30
-
31
- def pre_check(self, tensor_obj):
32
- """
33
- 检查张量是否符合标准(float类型且最大值大于对应精度最小值)
34
- """
35
- # 只针对第一个满足要求的添加扰动
36
- if self.is_added:
37
- return False
38
- if not torch.is_floating_point(tensor_obj):
39
- return False
40
- if not self._check_details(tensor_obj):
41
- return False
42
- return True
43
-
44
- def _check_details(self, tensor_obj):
45
- return True
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+
4
+ import torch
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
+
8
+
9
+ class NpuBaseLayer(BaseLayer):
10
+ def __init__(self, api_name: str) -> None:
11
+ super().__init__(api_name)
12
+ self.perturbed_value = None # 扰动的元素
13
+ self.is_added = False # 标记当前算子输入是否调整
14
+
15
+ @staticmethod
16
+ def perturbed_result(params: DataParams) -> Any:
17
+ args_front = params.args[: params.valid_input_index]
18
+ args_rear = params.args[params.valid_input_index + 1:]
19
+ # 此处会将有inplace属性的算子换为非inplace
20
+ if "inplace" in params.kwargs:
21
+ params.kwargs["inplace"] = False
22
+ params.perturbed_result = params.origin_func(
23
+ *args_front, params.perturbed_value, *args_rear, **params.kwargs
24
+ )
25
+ return params.perturbed_result
26
+
27
+ @abstractmethod
28
+ def handle(self, params: DataParams) -> Any:
29
+ pass
30
+
31
+ def pre_check(self, tensor_obj):
32
+ """
33
+ 检查张量是否符合标准(float类型且最大值大于对应精度最小值)
34
+ """
35
+ # 只针对第一个满足要求的添加扰动
36
+ if self.is_added:
37
+ return False
38
+ if not torch.is_floating_point(tensor_obj):
39
+ return False
40
+ if not self._check_details(tensor_obj):
41
+ return False
42
+ return True
43
+
44
+ def _check_details(self, tensor_obj):
45
+ return True
@@ -1,19 +1,19 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.params import DataParams
4
- from msprobe.pytorch.free_benchmark.common.utils import Tools
5
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
-
8
-
9
- class CpuLayer(BaseLayer):
10
-
11
- def handle(self, params: DataParams) -> torch.Any:
12
-
13
- logger.info_on_rank_0(
14
- f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
15
- )
16
- new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
17
- new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
18
- params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
19
- return params.perturbed_result
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
4
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
5
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
+
8
+
9
+ class CpuLayer(BaseLayer):
10
+
11
+ def handle(self, params: DataParams):
12
+
13
+ logger.info_on_rank_0(
14
+ f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
15
+ )
16
+ new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
17
+ new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
18
+ params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
19
+ return params.perturbed_result
@@ -1,203 +1,217 @@
1
- import math
2
- from abc import ABC, abstractmethod
3
- from typing import Any, Optional, Tuple
4
-
5
- import torch
6
- from msprobe.core.common.const import Const
7
- from msprobe.pytorch.free_benchmark import logger
8
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
9
- from msprobe.pytorch.free_benchmark.common.enums import (
10
- FuzzThreshold,
11
- NormType,
12
- PerturbationMode,
13
- )
14
- from msprobe.pytorch.free_benchmark.common.params import (
15
- DataParams,
16
- HandlerParams,
17
- make_unequal_row,
18
- )
19
- from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
20
-
21
-
22
- class FuzzHandler(ABC):
23
- def __init__(self, params: HandlerParams) -> None:
24
- self.params = params
25
- self.unequal_rows = []
26
-
27
- @staticmethod
28
- def pre_process(origin_ouput, perturbed_output):
29
- if (
30
- isinstance(origin_ouput, tuple)
31
- and hasattr(origin_ouput, "values")
32
- and hasattr(origin_ouput, "indices")
33
- ):
34
- origin_ouput = origin_ouput.values
35
- perturbed_output = perturbed_output.values
36
- if hasattr(perturbed_output, "dtype"):
37
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype)
38
- else:
39
- abs_tol = FuzzThreshold.F32_THD.value
40
- return (
41
- origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
42
- perturbed_output,
43
- abs_tol,
44
- )
45
-
46
- @staticmethod
47
- def convert_overflow_ratio_to_consistent(ratio):
48
- if math.isnan(ratio) or math.isinf(ratio):
49
- return ThresholdConfig.COMP_CONSISTENT
50
- return ratio
51
-
52
- @abstractmethod
53
- def get_threshold(self, dtype):
54
- pass
55
-
56
- @abstractmethod
57
- def handle(self, data_params: DataParams) -> Any:
58
- pass
59
-
60
- def get_ratio_from_specific_norm(
61
- self, origin_output, perturbed_output, norm_type, abs_tol
62
- ):
63
- if norm_type == NormType.ENDLESS_NORM:
64
- return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
65
- return ThresholdConfig.COMP_CONSISTENT
66
-
67
- def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
68
- ratio_tensor1 = TorchC.where(
69
- TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
70
- TorchC.div(
71
- TorchC.abs(origin_output),
72
- TorchC.add(TorchC.abs(perturbed_output), abs_tol),
73
- ),
74
- 1,
75
- )
76
- ratio_tensor2 = TorchC.where(
77
- TorchC.gt(TorchC.abs(origin_output), abs_tol),
78
- TorchC.div(
79
- TorchC.abs(perturbed_output),
80
- TorchC.add(TorchC.abs(origin_output), abs_tol),
81
- ),
82
- 1,
83
- )
84
-
85
- norm1 = self.convert_overflow_ratio_to_consistent(
86
- TorchC.max(ratio_tensor1).item()
87
- )
88
- norm2 = self.convert_overflow_ratio_to_consistent(
89
- TorchC.max(ratio_tensor2).item()
90
- )
91
- norm3 = self.convert_overflow_ratio_to_consistent(
92
- TorchC.min(ratio_tensor1).item()
93
- )
94
- if norm3 < 0:
95
- ratio = ThresholdConfig.SYMBOL_FLIPPING
96
- else:
97
- ratio = max(norm1, norm2)
98
- return ratio
99
-
100
- def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
101
- try:
102
- origin_output, perturbed_output, abs_tol = self.pre_process(
103
- origin_output, perturbed_output
104
- )
105
- except Exception as e:
106
- logger.warning_on_rank_0(
107
- f"[msprobe] Free Benchmark: For {self.params.api_name}, "
108
- f"when computing ratio,"
109
- f" y1 or y2 dtype is not supported {e}"
110
- )
111
- return ThresholdConfig.COMP_NAN
112
- if self.params.fuzz_stage == Const.BACKWARD:
113
- abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
114
- else:
115
- abs_tol = abs_tol ** 0.5
116
- return self.get_ratio_from_specific_norm(
117
- origin_output, perturbed_output, norm_type, abs_tol
118
- )
119
-
120
def npu_compare(self, origin_output, perturbed_output) -> Tuple[bool, Optional[float]]:
    """Compare one original output with its perturbed counterpart.

    :param origin_output: output produced without perturbation
    :param perturbed_output: output produced with perturbation; its type selects
        the comparison strategy
    :return: ``(is_consistent, ratio)``; ``ratio`` is ``None`` for ``int`` outputs
    """
    # Integers are compared exactly and carry no ratio.
    if isinstance(perturbed_output, int):
        return origin_output == perturbed_output, None

    if isinstance(perturbed_output, float):
        if perturbed_output == 0:
            # Shift both operands so the division below has a non-zero denominator.
            origin_output += FuzzThreshold.F32_THD
            perturbed_output += FuzzThreshold.F32_THD
        values_match = math.isclose(origin_output, perturbed_output)
        return values_match, origin_output / perturbed_output

    if not isinstance(perturbed_output, torch.Tensor):
        # Only a warning: execution still continues into the ratio-based comparison.
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name} "
            f"The compare for output type {type(perturbed_output)} is not supported"
        )

    first_dtype = Tools.get_first_tensor_dtype(origin_output)
    threshold = self.get_threshold(first_dtype)
    ratio = self.ratio_calculate(
        origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
    )
    if ratio == ThresholdConfig.SYMBOL_FLIPPING:
        return False, ratio
    # Consistent only when the ratio lies within [1 / threshold, threshold].
    return threshold >= ratio >= 1 / threshold, ratio
149
-
150
def cmp_output_npu(self, data_params: DataParams):
    """Compare the original and perturbed results held by ``data_params``.

    A tensor result is compared once; a list/tuple result is compared item by
    item. ``data_params.is_consistent`` is AND-ed with every comparison, and a
    row is appended to ``self.unequal_rows`` for each inconsistent comparison
    when ``data_params.grad_unequal_flag`` is truthy. Any exception is logged
    and swallowed.

    :param data_params: holder of ``original_result`` / ``perturbed_result``
    :return: ``(npu_consistent, max_fuzz_ratio)`` accumulated so far
    """
    npu_consistent = True
    max_fuzz_ratio = 0
    try:
        original = data_params.original_result
        if isinstance(original, torch.Tensor):
            # A bare tensor is reported without an index keyword.
            candidates = [({}, original, data_params.perturbed_result)]
        elif isinstance(original, (list, tuple)):
            # Generator: each perturbed item is looked up only when its turn comes.
            candidates = (
                ({"index": position}, item, data_params.perturbed_result[position])
                for position, item in enumerate(original)
            )
        else:
            candidates = ()

        for row_kwargs, origin_item, perturbed_item in candidates:
            is_consistent, ratio = self.npu_compare(origin_item, perturbed_item)
            npu_consistent = npu_consistent and is_consistent
            if ratio is not None:
                max_fuzz_ratio = max(max_fuzz_ratio, ratio)
            data_params.is_consistent = is_consistent and data_params.is_consistent
            if not is_consistent and data_params.grad_unequal_flag:
                self.unequal_rows.append(
                    make_unequal_row(data_params, self.params, ratio=ratio, **row_kwargs)
                )
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when campare the result exception raise {e}"
        )
    return npu_consistent, max_fuzz_ratio
192
-
193
def get_unequal_rows(self):
    """Return the list object stored in ``self.unequal_rows`` (not a copy)."""
    collected_rows = self.unequal_rows
    return collected_rows
195
-
196
def _get_default_threshold(self, dtype):
    """Return the default threshold for ``dtype``.

    :param dtype: key looked up in ``ThresholdConfig.DTYPE_PER_THD``
    :return: ``ThresholdConfig.COMP_CONSISTENT`` when the perturbation mode is
        ``PerturbationMode.NO_CHANGE``; otherwise the entry for ``dtype``,
        falling back to the entry for ``torch.float32``
    """
    if self.params.pert_mode == PerturbationMode.NO_CHANGE:
        return ThresholdConfig.COMP_CONSISTENT
    per_dtype = ThresholdConfig.DTYPE_PER_THD
    return per_dtype.get(dtype, per_dtype.get(torch.float32))
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Optional, Tuple
4
+ import numpy as np
5
+
6
+ import torch
7
+ from msprobe.core.common.const import Const
8
+ from msprobe.pytorch.free_benchmark import logger
9
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
10
+ from msprobe.pytorch.free_benchmark.common.enums import (
11
+ FuzzThreshold,
12
+ NormType,
13
+ PerturbationMode,
14
+ )
15
+ from msprobe.pytorch.free_benchmark.common.params import (
16
+ DataParams,
17
+ HandlerParams,
18
+ make_unequal_row,
19
+ )
20
+ from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
21
+
22
+
23
+ class FuzzHandler(ABC):
24
def __init__(self, params: HandlerParams) -> None:
    """Create a handler.

    :param params: handler configuration, stored unchanged on ``self.params``
    """
    # Rows describing inconsistent comparisons; starts empty.
    self.unequal_rows = []
    self.params = params
27
+
28
@staticmethod
def pre_process(origin_ouput, perturbed_output):
    """Prepare a pair of outputs for ratio computation.

    If the original output is a tuple exposing ``values`` and ``indices``
    attributes, only the ``values`` of both outputs are kept. The original
    output is then converted to the perturbed output's dtype and device.

    :param origin_ouput: output produced without perturbation
    :param perturbed_output: output produced with perturbation
    :return: ``(converted_origin, perturbed_output, abs_tol)`` where ``abs_tol``
        comes from ``ThresholdConfig.ABS_TOL_VALUE_DICT`` for the perturbed
        dtype, defaulting to ``FuzzThreshold.F32_THD``
    """
    has_values_and_indices = isinstance(origin_ouput, tuple) and all(
        hasattr(origin_ouput, field) for field in ("values", "indices")
    )
    if has_values_and_indices:
        origin_ouput = origin_ouput.values
        perturbed_output = perturbed_output.values

    abs_tol = FuzzThreshold.F32_THD
    if hasattr(perturbed_output, "dtype"):
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, abs_tol)

    converted_origin = origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device)
    return converted_origin, perturbed_output, abs_tol
46
+
47
@staticmethod
def tensor_split_for_error_calculate(origin_output, perturbed_output):
    """
    Split the outputs before and after perturbation into chunks for error calculation.

    :param origin_output: original output
    :param perturbed_output: output after perturbation
    :return origin_output_chunks: list of chunks of the flattened original output
    :return perturbed_output_chunks: list of chunks of the flattened perturbed output
    """
    single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
    if single_output_mem == 0 or origin_output.ndim == 0:
        # Nothing to split for empty or 0-dimensional tensors.
        return [origin_output], [perturbed_output]
    # Relation between tensor size and chunk count:
    #   chunks_exp = log2(M) - 4, chunks = 2 ** chunks_exp
    # where M is the byte size of the compared tensor divided by Const.ONE_MB.
    # math.log2 is documented as usually more accurate than math.log(M, 2),
    # which matters because the result is truncated to an int.
    chunks_exp = int(math.log2(single_output_mem)) - 4
    # A non-positive exponent means a single chunk; clamping the exponent
    # (rather than the power) keeps `chunks` an integer.
    chunks = 2 ** max(chunks_exp, 0)
    chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
    origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
    perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
    return origin_output_chunks, perturbed_output_chunks
67
+
68
+ @staticmethod
69
+ def convert_overflow_ratio_to_consistent(ratio):
70
+ if math.isnan(ratio) or math.isinf(ratio):
71
+ return ThresholdConfig.COMP_CONSISTENT
72
+ return ratio
73
+
74
@abstractmethod
def get_threshold(self, dtype):
    """Return the comparison threshold for ``dtype``; subclasses must implement this.

    ``npu_compare`` treats a ratio as consistent when it lies within
    ``[1 / threshold, threshold]``.
    """
    return None
77
+
78
@abstractmethod
def handle(self, data_params: DataParams) -> Any:
    """Process ``data_params``; subclasses must implement this.

    :param data_params: data for one comparison
    :return: defined by the subclass; this abstract stub returns ``None``
    """
    return None
81
+
82
def get_ratio_from_specific_norm(self, origin_output, perturbed_output, norm_type, abs_tol):
    """Dispatch the ratio computation on ``norm_type``.

    :param origin_output: output produced without perturbation
    :param perturbed_output: output produced with perturbation
    :param norm_type: only ``NormType.ENDLESS_NORM`` triggers a computation
    :param abs_tol: tolerance forwarded to ``calculate_error``
    :return: result of ``calculate_error`` for ``NormType.ENDLESS_NORM``,
        otherwise ``ThresholdConfig.COMP_CONSISTENT``
    """
    if norm_type != NormType.ENDLESS_NORM:
        return ThresholdConfig.COMP_CONSISTENT
    return self.calculate_error(origin_output, perturbed_output, abs_tol)
88
+
89
def calculate_error(self, origin_output, perturbed_output, abs_tol):
    """Compute the largest element-wise ratio between the two outputs.

    The outputs are split by ``tensor_split_for_error_calculate``. For every
    non-empty chunk two ratio tensors are built (origin / perturbed and
    perturbed / origin) with both operands clamped to at least ``abs_tol``;
    positions whose denominator magnitude does not exceed ``abs_tol`` are set
    to 1. NaN or infinite per-chunk maxima are replaced through
    ``convert_overflow_ratio_to_consistent``.

    :param origin_output: output produced without perturbation
    :param perturbed_output: output produced with perturbation
    :param abs_tol: clamp floor and magnitude cut-off for the ratios
    :return: ``ThresholdConfig.COMP_CONSISTENT`` when no element was compared
        (e.g. empty tensors), ``ThresholdConfig.SYMBOL_FLIPPING`` when the
        tracked minimum is negative, otherwise the largest ratio seen
    """
    origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(
        origin_output, perturbed_output
    )
    norm1 = -np.inf
    norm2 = -np.inf
    norm3 = np.inf
    compared_any = False
    for chunk_origin, chunk_perturbed in zip(origin_output_chunks, perturbed_output_chunks):
        if chunk_origin.nelement() == 0:
            break
        compared_any = True
        ratio_tensor1 = TorchC.where(
            TorchC.abs(chunk_perturbed) > abs_tol,
            TorchC.div(
                TorchC.clamp(chunk_origin, min=abs_tol),
                TorchC.clamp(chunk_perturbed, min=abs_tol),
            ),
            1,
        )
        ratio_tensor2 = TorchC.where(
            TorchC.abs(chunk_origin) > abs_tol,
            TorchC.div(
                TorchC.clamp(chunk_perturbed, min=abs_tol),
                TorchC.clamp(chunk_origin, min=abs_tol),
            ),
            1,
        )
        # Stack both maxima so they are converted to Python numbers in one call.
        norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
        max_ratio1, max_ratio2 = norm_values.tolist()
        norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
        norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
        norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))

    if not compared_any:
        # No element was compared, so there is no discrepancy to report. Without
        # this guard the result would be max(-inf, -inf) == -inf, which the
        # threshold check in npu_compare treats as inconsistent.
        return ThresholdConfig.COMP_CONSISTENT

    # NOTE(review): norm3 is fed from the per-chunk *maximum* of ratio_tensor1,
    # and both operands are clamped to >= abs_tol, so for a positive abs_tol it
    # cannot be negative and this branch looks unreachable; confirm intent.
    if norm3 < 0:
        ratio = ThresholdConfig.SYMBOL_FLIPPING
    else:
        ratio = max(norm1, norm2)
    return ratio
113
+
114
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
    """Compute the comparison ratio between the original and perturbed outputs.

    :param origin_output: output produced without perturbation
    :param perturbed_output: output produced with perturbation
    :param norm_type: norm selector forwarded to ``get_ratio_from_specific_norm``
    :return: ``ThresholdConfig.COMP_NAN`` if ``pre_process`` raises, otherwise
        the value returned by ``get_ratio_from_specific_norm``
    """
    try:
        prepared = self.pre_process(origin_output, perturbed_output)
        origin_output, perturbed_output, abs_tol = prepared
    except Exception as e:
        # Best effort: a failure while preparing the operands is logged, not raised.
        message = (
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when computing ratio,"
            f" y1 or y2 dtype is not supported {e}"
        )
        logger.warning_on_rank_0(message)
        return ThresholdConfig.COMP_NAN

    # Backward stage uses a fixed lower bound; otherwise the square root of the
    # tolerance returned by pre_process is used.
    is_backward = self.params.fuzz_stage == Const.BACKWARD
    abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND if is_backward else abs_tol ** 0.5
    return self.get_ratio_from_specific_norm(
        origin_output, perturbed_output, norm_type, abs_tol
    )
133
+
134
def npu_compare(self, origin_output, perturbed_output) -> Tuple[bool, Optional[float]]:
    """Compare one original output with its perturbed counterpart.

    :param origin_output: output produced without perturbation
    :param perturbed_output: output produced with perturbation; its type selects
        the comparison strategy
    :return: ``(is_consistent, ratio)``; ``ratio`` is ``None`` for ``int`` outputs
    """
    # Integers are compared exactly and carry no ratio.
    if isinstance(perturbed_output, int):
        return origin_output == perturbed_output, None

    if isinstance(perturbed_output, float):
        if perturbed_output == 0:
            # Shift both operands so the division below has a non-zero denominator.
            origin_output += FuzzThreshold.F32_THD
            perturbed_output += FuzzThreshold.F32_THD
        values_match = math.isclose(origin_output, perturbed_output)
        return values_match, origin_output / perturbed_output

    if not isinstance(perturbed_output, torch.Tensor):
        # Only a warning: execution still continues into the ratio-based comparison.
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name} "
            f"The compare for output type {type(perturbed_output)} is not supported"
        )

    first_dtype = Tools.get_first_tensor_dtype(origin_output)
    threshold = self.get_threshold(first_dtype)
    ratio = self.ratio_calculate(
        origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
    )
    if ratio == ThresholdConfig.SYMBOL_FLIPPING:
        return False, ratio
    # Consistent only when the ratio lies within [1 / threshold, threshold].
    return threshold >= ratio >= 1 / threshold, ratio
163
+
164
def cmp_output_npu(self, data_params: DataParams):
    """Compare the original and perturbed results held by ``data_params``.

    A tensor result is compared once; a list/tuple result is compared item by
    item. ``data_params.is_consistent`` is AND-ed with every comparison, and a
    row is appended to ``self.unequal_rows`` for each inconsistent comparison
    when ``data_params.grad_unequal_flag`` is truthy. Any exception is logged
    and swallowed.

    :param data_params: holder of ``original_result`` / ``perturbed_result``
    :return: ``(npu_consistent, max_fuzz_ratio)`` accumulated so far
    """
    npu_consistent = True
    max_fuzz_ratio = 0
    try:
        original = data_params.original_result
        if isinstance(original, torch.Tensor):
            # A bare tensor is reported without an index keyword.
            candidates = [({}, original, data_params.perturbed_result)]
        elif isinstance(original, (list, tuple)):
            # Generator: each perturbed item is looked up only when its turn comes.
            candidates = (
                ({"index": position}, item, data_params.perturbed_result[position])
                for position, item in enumerate(original)
            )
        else:
            candidates = ()

        for row_kwargs, origin_item, perturbed_item in candidates:
            is_consistent, ratio = self.npu_compare(origin_item, perturbed_item)
            npu_consistent = npu_consistent and is_consistent
            if ratio is not None:
                max_fuzz_ratio = max(max_fuzz_ratio, ratio)
            data_params.is_consistent = is_consistent and data_params.is_consistent
            if not is_consistent and data_params.grad_unequal_flag:
                self.unequal_rows.append(
                    make_unequal_row(data_params, self.params, ratio=ratio, **row_kwargs)
                )
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when campare the result exception raise {e}"
        )
    return npu_consistent, max_fuzz_ratio
206
+
207
def get_unequal_rows(self):
    """Return the list object stored in ``self.unequal_rows`` (not a copy)."""
    collected_rows = self.unequal_rows
    return collected_rows
209
+
210
def _get_default_threshold(self, dtype):
    """Return the default threshold for ``dtype``.

    :param dtype: key looked up in ``ThresholdConfig.DTYPE_PER_THD``
    :return: ``ThresholdConfig.COMP_CONSISTENT`` when the perturbation mode is
        ``PerturbationMode.NO_CHANGE``; otherwise the entry for ``dtype``,
        falling back to the entry for ``torch.float32``
    """
    if self.params.pert_mode == PerturbationMode.NO_CHANGE:
        return ThresholdConfig.COMP_CONSISTENT
    per_dtype = ThresholdConfig.DTYPE_PER_THD
    return per_dtype.get(dtype, per_dtype.get(torch.float32))