mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,42 +1,57 @@
1
- from msprobe.mindspore.free_benchmark.common.config import Config
2
- from msprobe.mindspore.common.const import FreeBenchmarkConst
3
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
4
- from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
5
- from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
6
-
7
-
8
- class ForwardSelfChecker:
9
-
10
- def __init__(self, api_name: str):
11
- self.api_name = api_name
12
-
13
- def handle(self, params: HandlerParams):
14
- """
15
- 装饰器实际执行逻辑
16
-
17
- """
18
- perturbation = PerturbationFactory.create(self.api_name)
19
- params.fuzzed_result = perturbation.handle(params)
20
- params.original_result = params.original_func(*params.args, **params.kwargs)
21
- if params.fuzzed_result is not False:
22
- return self.deal_fuzzed_and_original_result(params)
23
- return params.original_result
24
-
25
- def get_compare_data(self, params: HandlerParams):
26
- if self.api_name not in FreeBenchmarkConst.COMMUNICATION_API_LIST:
27
- return
28
- # 以下为通讯类api处理逻辑
29
- params.fuzzed_result = params.fuzzed_value
30
- if Config.pert_type == FreeBenchmarkConst.IMPROVE_PRECISION:
31
- params.original_result = params.args
32
- else:
33
- params.original_result = params.args[params.index]
34
-
35
- def deal_fuzzed_and_original_result(self, params: HandlerParams):
36
- original_result = params.original_result
37
- self.get_compare_data(params)
38
- handler = HandlerFactory.create(self.api_name)
39
- result = handler.handle(params)
40
- if self.api_name in FreeBenchmarkConst.COMMUNICATION_API_LIST:
41
- result = original_result
42
- return result
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.common.const import Const, FreeBenchmarkConst
17
+ from msprobe.mindspore.free_benchmark.common.config import Config
18
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
19
+ from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
20
+ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
21
+
22
+
23
+ class ForwardSelfChecker:
24
+
25
+ def __init__(self, api_name: str):
26
+ self.api_name = api_name
27
+
28
+ def handle(self, params: HandlerParams):
29
+ """
30
+ 装饰器实际执行逻辑
31
+
32
+ """
33
+ perturbation = PerturbationFactory.create(self.api_name)
34
+ params.fuzzed_result = perturbation.handle(params)
35
+ params.original_result = params.original_func(*params.args, **params.kwargs)
36
+ if params.fuzzed_result is not False:
37
+ return self.deal_fuzzed_and_original_result(params)
38
+ return params.original_result
39
+
40
+ def get_compare_data(self, params: HandlerParams):
41
+ if self.api_name not in Const.COMMUNICATION_API_LIST:
42
+ return
43
+ # 以下为通讯类api处理逻辑
44
+ params.fuzzed_result = params.fuzzed_value
45
+ if Config.pert_type == FreeBenchmarkConst.IMPROVE_PRECISION:
46
+ params.original_result = params.args
47
+ else:
48
+ params.original_result = params.args[params.index]
49
+
50
+ def deal_fuzzed_and_original_result(self, params: HandlerParams):
51
+ original_result = params.original_result
52
+ self.get_compare_data(params)
53
+ handler = HandlerFactory.create(self.api_name)
54
+ result = handler.handle(params)
55
+ if self.api_name in Const.COMMUNICATION_API_LIST:
56
+ result = original_result
57
+ return result
@@ -1,107 +1,122 @@
1
- import os
2
- import sys
3
- import traceback
4
- from functools import wraps
5
- from typing import Tuple, Dict, List
6
-
7
- from mindspore import ops
8
-
9
- from msprobe.mindspore.runtime import Runtime
10
- from msprobe.mindspore.common.log import logger
11
- from msprobe.mindspore.free_benchmark.common.config import Config
12
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
13
- from .dec_forward import ForwardSelfChecker
14
-
15
-
16
- def decorate(original_func, decorate_func, api_name=None):
17
- """
18
- 总装饰器
19
- """
20
- @wraps(original_func)
21
- def fuzz_wrapper(*args, **kwargs):
22
-
23
- def __exec_decorate_func():
24
- params = data_pre_deal(api_name, original_func, *args, **kwargs)
25
- result = decorate_func(params)
26
- return result
27
-
28
- try:
29
- if Runtime.rank_id == -1:
30
- Runtime.rank_id = os.environ.get("RANK_ID", -1)
31
- if need_wrapper_func():
32
- logger.info(f"[{api_name}] is checking.")
33
- return __exec_decorate_func()
34
- except Exception as e:
35
- logger.error(f"[{api_name}] Error: {str(e)}")
36
- logger.error(f"[{api_name}] Error detail: {traceback.format_exc()}")
37
-
38
- return original_func(*args, **kwargs)
39
-
40
- return fuzz_wrapper
41
-
42
-
43
- def decorate_forward_function(func, api_name=None):
44
- """
45
- 前向装饰器
46
- """
47
-
48
- if not api_name:
49
- api_name = func.__name__
50
-
51
- def forward_func(params: HandlerParams):
52
- forward = ForwardSelfChecker(api_name)
53
- result = forward.handle(params)
54
- return result
55
-
56
- return decorate(func, forward_func, api_name)
57
-
58
-
59
- def stack_depth_check() -> bool:
60
- nested_depth = 1
61
- frame = sys._getframe(1)
62
- while frame:
63
- if frame.f_code.co_name == "fuzz_wrapper":
64
- nested_depth -= 1
65
- if nested_depth < 0:
66
- return False
67
- frame = frame.f_back
68
- return True
69
-
70
-
71
- def get_target_arg_index(args: Tuple) -> int:
72
- """
73
- 类型校验
74
-
75
- """
76
- for i, arg in enumerate(args):
77
- if ops.is_tensor(arg):
78
- if not ops.is_floating_point(arg):
79
- continue
80
- return i
81
- if isinstance(arg, (List, Tuple, Dict)):
82
- return i
83
- return -1
84
-
85
-
86
- def data_pre_deal(api_name, func, *args, **kwargs):
87
- params = HandlerParams()
88
- params.args = args
89
- params.kwargs = kwargs
90
- params.original_func = func
91
- index = get_target_arg_index(args)
92
- if index == -1:
93
- raise Exception(f"{api_name} has no supported input type")
94
- params.index = index
95
- return params
96
-
97
-
98
- def need_wrapper_func():
99
- if not (Runtime.is_running and Config.is_enable):
100
- return False
101
- if not stack_depth_check():
102
- return False
103
- if Config.steps and Runtime.step_count not in Config.steps:
104
- return False
105
- if Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks:
106
- return False
107
- return True
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import sys
18
+ import traceback
19
+ from functools import wraps
20
+ from typing import Dict, List, Tuple
21
+
22
+ from mindspore import ops
23
+
24
+ from msprobe.mindspore.common.log import logger
25
+ from msprobe.mindspore.free_benchmark.common.config import Config
26
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
27
+ from msprobe.mindspore.free_benchmark.decorator.dec_forward import ForwardSelfChecker
28
+ from msprobe.mindspore.runtime import Runtime
29
+
30
+
31
+ def decorate(original_func, decorate_func, api_name=None):
32
+ """
33
+ 总装饰器
34
+ """
35
+ @wraps(original_func)
36
+ def fuzz_wrapper(*args, **kwargs):
37
+
38
+ def __exec_decorate_func():
39
+ params = data_pre_deal(api_name, original_func, *args, **kwargs)
40
+ result = decorate_func(params)
41
+ return result
42
+
43
+ try:
44
+ if Runtime.rank_id == -1:
45
+ Runtime.rank_id = os.environ.get("RANK_ID", -1)
46
+ if need_wrapper_func():
47
+ logger.info(f"[{api_name}] is checking.")
48
+ return __exec_decorate_func()
49
+ except Exception as e:
50
+ logger.error(f"[{api_name}] Error: {str(e)}")
51
+ logger.error(f"[{api_name}] Error detail: {traceback.format_exc()}")
52
+
53
+ return original_func(*args, **kwargs)
54
+
55
+ return fuzz_wrapper
56
+
57
+
58
+ def decorate_forward_function(func, api_name=None):
59
+ """
60
+ 前向装饰器
61
+ """
62
+
63
+ if not api_name:
64
+ api_name = func.__name__
65
+
66
+ def forward_func(params: HandlerParams):
67
+ forward = ForwardSelfChecker(api_name)
68
+ result = forward.handle(params)
69
+ return result
70
+
71
+ return decorate(func, forward_func, api_name)
72
+
73
+
74
+ def stack_depth_check() -> bool:
75
+ nested_depth = 1
76
+ frame = sys._getframe(1)
77
+ while frame:
78
+ if frame.f_code.co_name == "fuzz_wrapper":
79
+ nested_depth -= 1
80
+ if nested_depth < 0:
81
+ return False
82
+ frame = frame.f_back
83
+ return True
84
+
85
+
86
+ def get_target_arg_index(args: Tuple) -> int:
87
+ """
88
+ 类型校验
89
+
90
+ """
91
+ for i, arg in enumerate(args):
92
+ if ops.is_tensor(arg):
93
+ if not ops.is_floating_point(arg):
94
+ continue
95
+ return i
96
+ if isinstance(arg, (List, Tuple, Dict)):
97
+ return i
98
+ return -1
99
+
100
+
101
+ def data_pre_deal(api_name, func, *args, **kwargs):
102
+ params = HandlerParams()
103
+ params.args = args
104
+ params.kwargs = kwargs
105
+ params.original_func = func
106
+ index = get_target_arg_index(args)
107
+ if index == -1:
108
+ raise Exception(f"{api_name} has no supported input type")
109
+ params.index = index
110
+ return params
111
+
112
+
113
+ def need_wrapper_func():
114
+ if not (Runtime.is_running and Config.is_enable):
115
+ return False
116
+ if not stack_depth_check():
117
+ return False
118
+ if Config.steps and Runtime.step_count not in Config.steps:
119
+ return False
120
+ if Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks:
121
+ return False
122
+ return True
@@ -1,90 +1,105 @@
1
- import math
2
- from abc import ABC, abstractmethod
3
- from typing import Any, Tuple, Optional
4
-
5
- import mindspore as ms
6
- from mindspore import Tensor, ops
7
-
8
- from msprobe.mindspore.common.log import logger
9
- from msprobe.mindspore.free_benchmark.common.utils import Tools
10
- from msprobe.mindspore.common.const import FreeBenchmarkConst
11
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
12
-
13
-
14
class BaseHandler(ABC):
    """
    Abstract base for free-benchmark handlers.

    Stores the API name and provides static helpers for comparing an original
    output with the corresponding fuzzed output. Concrete handlers implement
    ``handle``.
    """

    def __init__(self, api_name: str):
        # Name of the API this handler instance reports on.
        self.api_name = api_name

    @staticmethod
    def pre_calculate(original_output, fuzzed_output):
        """
        Prepare both outputs for the ratio computation.

        Returns:
            A 3-tuple: the original output cast to the fuzzed output's dtype,
            the fuzzed output unchanged, and the tolerance found in
            ``FreeBenchmarkConst.PERT_VALUE_DICT`` for that dtype (the
            ``ms.float32`` entry is used when the dtype has no entry).
        """
        target_dtype = fuzzed_output.dtype
        fallback_tol = FreeBenchmarkConst.PERT_VALUE_DICT.get(ms.float32)
        abs_tol = FreeBenchmarkConst.PERT_VALUE_DICT.get(target_dtype, fallback_tol)
        return original_output.to(target_dtype), fuzzed_output, abs_tol

    @staticmethod
    def get_threshold(dtype):
        """Return the default error threshold that ``Tools`` reports for ``dtype``."""
        return Tools.get_default_error_threshold(dtype)

    @staticmethod
    def convert_overflow_ratio_to_consistent(ratio):
        """
        Map a NaN or infinite ratio to
        ``FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD``; return finite ratios as is.
        """
        if math.isfinite(ratio):
            return ratio
        return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD

    @staticmethod
    def get_endless_norm(first_tensor, second_tensor, abs_tol):
        """
        Summarise the element-wise ratio between two tensors.

        Two ratio tensors are built, first/second and second/first; positions
        whose divisor has a magnitude not above ``abs_tol`` are set to 1. When
        either input is bfloat16 the comparison and division operate on
        float32 casts.

        Returns:
            ``FreeBenchmarkConst.SYMBOL_FLIPPING_RATIO`` when the minimum of
            the first/second ratios is negative; otherwise the larger of the
            two ratio maxima. Each extremum is passed through
            ``convert_overflow_ratio_to_consistent`` first.
        """
        if first_tensor.dtype == ms.bfloat16 or second_tensor.dtype == ms.bfloat16:
            first_f32 = first_tensor.to(ms.float32)
            second_f32 = second_tensor.to(ms.float32)
            forward_ratio = ops.where(ops.abs(second_tensor).to(ms.float32) > abs_tol,
                                      ops.div(first_f32, second_f32), 1)
            backward_ratio = ops.where(ops.abs(first_tensor).to(ms.float32) > abs_tol,
                                       ops.div(second_f32, first_f32), 1)
        else:
            forward_ratio = ops.where(ops.abs(second_tensor) > abs_tol, ops.div(first_tensor, second_tensor), 1)
            backward_ratio = ops.where(ops.abs(first_tensor) > abs_tol, ops.div(second_tensor, first_tensor), 1)

        def extremum(reduced):
            # Element 0 of the ops.max / ops.min result is the value that gets compared.
            return BaseHandler.convert_overflow_ratio_to_consistent(reduced[0].to(ms.float32).item())

        largest_forward = extremum(ops.max(forward_ratio))
        largest_backward = extremum(ops.max(backward_ratio))
        smallest_forward = extremum(ops.min(forward_ratio))

        if smallest_forward < 0:
            return FreeBenchmarkConst.SYMBOL_FLIPPING_RATIO
        return max(largest_forward, largest_backward)

    @staticmethod
    def ratio_calculate(original_output, fuzzed_output) -> float:
        """
        Compute the comparison ratio for one output pair.

        If ``pre_calculate`` raises, the error is logged and
        ``FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD`` is returned. Otherwise
        the square root of the looked-up tolerance is handed to
        ``get_endless_norm``.
        """
        try:
            original_cast, fuzzed, base_tol = BaseHandler.pre_calculate(original_output, fuzzed_output)
        except Exception as e:
            logger.error(f"When computing ratio, y1 or y2 dtype is not supported {str(e)}")
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD

        return BaseHandler.get_endless_norm(original_cast, fuzzed, base_tol ** 0.5)

    @staticmethod
    def npu_compare(original_output, fuzzed_output) -> Tuple[bool, Optional[float]]:
        """
        Compare an original output with its fuzzed counterpart.

        Returns:
            ``(is_consistent, ratio)``. A fuzzed output that is not a
            ``Tensor`` is logged as unsupported and reported as
            ``(True, 1.0)``. Otherwise the pair is consistent when the ratio
            lies between ``1.0 / threshold`` and ``threshold`` inclusive.
        """
        if isinstance(fuzzed_output, Tensor):
            # Norm-style ratio computation.
            threshold = BaseHandler.get_threshold(original_output.dtype)
            ratio = BaseHandler.ratio_calculate(original_output, fuzzed_output)
            return threshold >= ratio >= 1.0 / threshold, ratio

        logger.error(f"The compare for output type `{type(fuzzed_output)}` is not supported")
        return True, 1.0

    @staticmethod
    def is_float_tensor(output) -> bool:
        """
        Return True when ``output`` is a floating-point ``Tensor``, or a list
        or tuple with at least one floating-point ``Tensor`` among its items.
        """
        def _is_float(value):
            return isinstance(value, Tensor) and ops.is_floating_point(value)

        if _is_float(output):
            return True
        if isinstance(output, (list, tuple)):
            return any(_is_float(item) for item in output)
        return False

    @abstractmethod
    def handle(self, params: HandlerParams) -> Any:
        """Process ``params``; concrete subclasses supply the behaviour."""
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from typing import Any, Optional, Tuple
19
+
20
+ import mindspore as ms
21
+ from mindspore import Tensor, ops
22
+
23
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
24
+ from msprobe.mindspore.common.log import logger
25
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
26
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
27
+
28
+
29
class BaseHandler(ABC):
    """
    Abstract base for free-benchmark handlers.

    Stores the API name and provides static helpers for comparing an original
    output with the corresponding fuzzed output. Concrete handlers implement
    ``handle``.
    """

    def __init__(self, api_name: str):
        # Name of the API this handler instance reports on.
        self.api_name = api_name

    @staticmethod
    def pre_calculate(original_output, fuzzed_output):
        """
        Prepare both outputs for the ratio computation.

        Returns:
            A 3-tuple: the original output cast to the fuzzed output's dtype,
            the fuzzed output unchanged, and the tolerance found in
            ``FreeBenchmarkConst.PERT_VALUE_DICT`` for that dtype (the
            ``ms.float32`` entry is used when the dtype has no entry).
        """
        target_dtype = fuzzed_output.dtype
        fallback_tol = FreeBenchmarkConst.PERT_VALUE_DICT.get(ms.float32)
        abs_tol = FreeBenchmarkConst.PERT_VALUE_DICT.get(target_dtype, fallback_tol)
        return original_output.to(target_dtype), fuzzed_output, abs_tol

    @staticmethod
    def get_threshold(dtype):
        """Return the default error threshold that ``Tools`` reports for ``dtype``."""
        return Tools.get_default_error_threshold(dtype)

    @staticmethod
    def convert_overflow_ratio_to_consistent(ratio):
        """
        Map a NaN or infinite ratio to
        ``FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD``; return finite ratios as is.
        """
        if math.isfinite(ratio):
            return ratio
        return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD

    @staticmethod
    def get_endless_norm(first_tensor, second_tensor, abs_tol):
        """
        Summarise the element-wise ratio between two tensors.

        Two ratio tensors are built, first/second and second/first; positions
        whose divisor has a magnitude not above ``abs_tol`` are set to 1. When
        either input is bfloat16 the comparison and division operate on
        float32 casts.

        Returns:
            ``FreeBenchmarkConst.SYMBOL_FLIPPING_RATIO`` when the minimum of
            the first/second ratios is negative; otherwise the larger of the
            two ratio maxima. Each extremum is passed through
            ``convert_overflow_ratio_to_consistent`` first.
        """
        if first_tensor.dtype == ms.bfloat16 or second_tensor.dtype == ms.bfloat16:
            first_f32 = first_tensor.to(ms.float32)
            second_f32 = second_tensor.to(ms.float32)
            forward_ratio = ops.where(ops.abs(second_tensor).to(ms.float32) > abs_tol,
                                      ops.div(first_f32, second_f32), 1)
            backward_ratio = ops.where(ops.abs(first_tensor).to(ms.float32) > abs_tol,
                                       ops.div(second_f32, first_f32), 1)
        else:
            forward_ratio = ops.where(ops.abs(second_tensor) > abs_tol, ops.div(first_tensor, second_tensor), 1)
            backward_ratio = ops.where(ops.abs(first_tensor) > abs_tol, ops.div(second_tensor, first_tensor), 1)

        def extremum(reduced):
            # Element 0 of the ops.max / ops.min result is the value that gets compared.
            return BaseHandler.convert_overflow_ratio_to_consistent(reduced[0].to(ms.float32).item())

        largest_forward = extremum(ops.max(forward_ratio))
        largest_backward = extremum(ops.max(backward_ratio))
        smallest_forward = extremum(ops.min(forward_ratio))

        if smallest_forward < 0:
            return FreeBenchmarkConst.SYMBOL_FLIPPING_RATIO
        return max(largest_forward, largest_backward)

    @staticmethod
    def ratio_calculate(original_output, fuzzed_output) -> float:
        """
        Compute the comparison ratio for one output pair.

        If ``pre_calculate`` raises, the error is logged and
        ``FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD`` is returned. Otherwise
        the square root of the looked-up tolerance is handed to
        ``get_endless_norm``.
        """
        try:
            original_cast, fuzzed, base_tol = BaseHandler.pre_calculate(original_output, fuzzed_output)
        except Exception as e:
            logger.error(f"When computing ratio, y1 or y2 dtype is not supported {str(e)}")
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD

        return BaseHandler.get_endless_norm(original_cast, fuzzed, base_tol ** 0.5)

    @staticmethod
    def npu_compare(original_output, fuzzed_output) -> Tuple[bool, Optional[float]]:
        """
        Compare an original output with its fuzzed counterpart.

        Returns:
            ``(is_consistent, ratio)``. A fuzzed output that is not a
            ``Tensor`` is logged as unsupported and reported as
            ``(True, 1.0)``. Otherwise the pair is consistent when the ratio
            lies between ``1.0 / threshold`` and ``threshold`` inclusive.
        """
        if isinstance(fuzzed_output, Tensor):
            # Norm-style ratio computation.
            threshold = BaseHandler.get_threshold(original_output.dtype)
            ratio = BaseHandler.ratio_calculate(original_output, fuzzed_output)
            return threshold >= ratio >= 1.0 / threshold, ratio

        logger.error(f"The compare for output type `{type(fuzzed_output)}` is not supported")
        return True, 1.0

    @staticmethod
    def is_float_tensor(output) -> bool:
        """
        Return True when ``output`` is a floating-point ``Tensor``, or a list
        or tuple with at least one floating-point ``Tensor`` among its items.
        """
        def _is_float(value):
            return isinstance(value, Tensor) and ops.is_floating_point(value)

        if _is_float(output):
            return True
        if isinstance(output, (list, tuple)):
            return any(_is_float(item) for item in output)
        return False

    @abstractmethod
    def handle(self, params: HandlerParams) -> Any:
        """Process ``params``; concrete subclasses supply the behaviour."""
@@ -1,41 +1,56 @@
1
- from typing import Any
2
- from dataclasses import asdict
3
-
4
- from mindspore import Tensor, ops
5
-
6
- from msprobe.mindspore.common.log import logger
7
- from msprobe.mindspore.free_benchmark.common.config import Config
8
- from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
9
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
10
- from msprobe.mindspore.free_benchmark.common.utils import make_unequal_row
11
- from msprobe.core.data_dump.json_writer import DataWriter
12
-
13
-
14
class CheckHandler(BaseHandler):
    """
    Handler that compares original and fuzzed results and records mismatches.
    """

    def npu_compare_and_save(self, original_output, fuzzed_output, params: HandlerParams, output_index=None):
        """
        Compare one output pair and record the verdict.

        ``params.is_consistent`` is AND-ed with the comparison result. On a
        mismatch, the row built by ``make_unequal_row`` is written through
        ``DataWriter.write_data_to_csv`` using ``Config.dump_path``, and an
        error is logged.
        """
        consistent, ratio = self.npu_compare(original_output, fuzzed_output)
        params.is_consistent = params.is_consistent and consistent
        if consistent:
            return

        record = asdict(make_unequal_row(self.api_name, params, ratio, output_index))
        DataWriter.write_data_to_csv(record.values(), record.keys(), Config.dump_path)
        logger.error(f"{self.api_name} is not consistent")

    def handle(self, params: HandlerParams) -> Any:
        """
        Check ``params.fuzzed_result`` against ``params.original_result``.

        A single ``Tensor`` result is compared directly. For a list or tuple,
        each floating-point tensor in the original result is compared with the
        fuzzed item at the same index. Any exception raised while checking is
        logged rather than propagated.

        Returns:
            ``params.original_result`` in every case.
        """
        try:
            fuzzed = params.fuzzed_result
            if not self.is_float_tensor(fuzzed):
                return params.original_result

            if isinstance(fuzzed, Tensor):
                self.npu_compare_and_save(params.original_result, fuzzed, params)
            elif isinstance(fuzzed, (list, tuple)):
                for position, original_item in enumerate(params.original_result):
                    if not (ops.is_tensor(original_item) and ops.is_floating_point(original_item)):
                        continue
                    self.npu_compare_and_save(original_item, fuzzed[position], params, output_index=position)
        except Exception as e:
            logger.error(str(e))
        return params.original_result
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import asdict
17
+ from typing import Any
18
+
19
+ from mindspore import Tensor, ops
20
+
21
+ from msprobe.core.data_dump.json_writer import DataWriter
22
+ from msprobe.mindspore.common.log import logger
23
+ from msprobe.mindspore.free_benchmark.common.config import Config
24
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
+ from msprobe.mindspore.free_benchmark.common.utils import make_unequal_row
26
+ from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
27
+
28
+
29
class CheckHandler(BaseHandler):
    """
    Handler that compares original and fuzzed results and records mismatches.
    """

    def npu_compare_and_save(self, original_output, fuzzed_output, params: HandlerParams, output_index=None):
        """
        Compare one output pair and record the verdict.

        ``params.is_consistent`` is AND-ed with the comparison result. On a
        mismatch, the row built by ``make_unequal_row`` is written through
        ``DataWriter.write_data_to_csv`` using ``Config.dump_path``, and an
        error is logged.
        """
        consistent, ratio = self.npu_compare(original_output, fuzzed_output)
        params.is_consistent = params.is_consistent and consistent
        if consistent:
            return

        record = asdict(make_unequal_row(self.api_name, params, ratio, output_index))
        DataWriter.write_data_to_csv(record.values(), record.keys(), Config.dump_path)
        logger.error(f"{self.api_name} is not consistent")

    def handle(self, params: HandlerParams) -> Any:
        """
        Check ``params.fuzzed_result`` against ``params.original_result``.

        A single ``Tensor`` result is compared directly. For a list or tuple,
        each floating-point tensor in the original result is compared with the
        fuzzed item at the same index. Any exception raised while checking is
        logged rather than propagated.

        Returns:
            ``params.original_result`` in every case.
        """
        try:
            fuzzed = params.fuzzed_result
            if not self.is_float_tensor(fuzzed):
                return params.original_result

            if isinstance(fuzzed, Tensor):
                self.npu_compare_and_save(params.original_result, fuzzed, params)
            elif isinstance(fuzzed, (list, tuple)):
                for position, original_item in enumerate(params.original_result):
                    if not (ops.is_tensor(original_item) and ops.is_floating_point(original_item)):
                        continue
                    self.npu_compare_and_save(original_item, fuzzed[position], params, output_index=position)
        except Exception as e:
            logger.error(str(e))
        return params.original_result