mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,99 +0,0 @@
1
- # coding=utf-8
2
- import os
3
- import unittest
4
- import copy
5
-
6
- from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import *
7
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
8
-
9
- base_dir = os.path.dirname(os.path.realpath(__file__))
10
- forward_file = os.path.join(base_dir, "forward.json")
11
- forward_content = get_json_contents(forward_file)
12
- for key, value in forward_content.items():
13
- api_full_name = key
14
- api_info_dict = value
15
-
16
- max_value = 1.3945078125
17
- min_value = -1.444359375
18
-
19
-
20
- class TestDataGenerateMethods(unittest.TestCase):
21
- def test_gen_api_params(self):
22
- api_info = copy.deepcopy(api_info_dict)
23
- args_params, kwargs_params = gen_api_params(api_info, True, None, None)
24
- max_diff = abs(args_params[0].max() - max_value)
25
- min_diff = abs(args_params[0].min() - min_value)
26
- self.assertEqual(len(args_params), 2)
27
- self.assertEqual(args_params[0].dtype, torch.float16)
28
- self.assertEqual(args_params[1], 2)
29
- self.assertLessEqual(max_diff, 0.001)
30
- self.assertLessEqual(min_diff, 0.001)
31
- self.assertEqual(args_params[0].shape, torch.Size([2048, 2, 1, 256]))
32
- self.assertEqual(kwargs_params, {'dim': -1})
33
-
34
- def test_gen_args(self):
35
- args_result = gen_args(api_info_dict.get('input_args'), "conv2d")
36
- max_diff = abs(args_result[0].max() - max_value)
37
- min_diff = abs(args_result[0].min() - min_value)
38
- self.assertEqual(len(args_result), 2)
39
- self.assertEqual(args_result[0].dtype, torch.float16)
40
- self.assertLessEqual(max_diff, 0.001)
41
- self.assertLessEqual(min_diff, 0.001)
42
- self.assertEqual(args_result[0].shape, torch.Size([2048, 2, 1, 256]))
43
-
44
- def test_gen_data(self):
45
- data = gen_data(api_info_dict.get('input_args')[0], "conv2d", True, None, None)
46
- max_diff = abs(data.max() - max_value)
47
- min_diff = abs(data.min() - min_value)
48
- self.assertEqual(data.dtype, torch.float16)
49
- self.assertEqual(data.requires_grad, True)
50
- self.assertLessEqual(max_diff, 0.001)
51
- self.assertLessEqual(min_diff, 0.001)
52
- self.assertEqual(data.shape, torch.Size([2048, 2, 1, 256]))
53
-
54
- def test_gen_kwargs(self):
55
- api_info = copy.deepcopy(api_info_dict)
56
- kwargs_params = gen_kwargs(api_info, None)
57
- self.assertEqual(kwargs_params, {'dim': -1})
58
-
59
- def test_gen_kwargs_2(self):
60
- k_dict = {"inplace": {"type": "bool", "value": "False"}}
61
- for key, value in k_dict.items():
62
- gen_torch_kwargs(k_dict, key, value)
63
- self.assertEqual(k_dict, {'inplace': False})
64
-
65
- def test_gen_random_tensor(self):
66
- data = gen_random_tensor(api_info_dict.get('input_args')[0], None)
67
- max_diff = abs(data.max() - max_value)
68
- min_diff = abs(data.min() - min_value)
69
- self.assertEqual(data.dtype, torch.float16)
70
- self.assertEqual(data.requires_grad, False)
71
- self.assertLessEqual(max_diff, 0.001)
72
- self.assertLessEqual(min_diff, 0.001)
73
- self.assertEqual(data.shape, torch.Size([2048, 2, 1, 256]))
74
-
75
- def test_gen_common_tensor(self):
76
- info = api_info_dict.get('input_args')[0]
77
- low, high = info.get('Min'), info.get('Max')
78
- low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
79
- low_info = [low, low_origin]
80
- high_info = [high, high_origin]
81
- data_dtype = info.get('dtype')
82
- shape = tuple(info.get('shape'))
83
- data = gen_common_tensor(low_info, high_info, shape, data_dtype, None)
84
- max_diff = abs(data.max() - max_value)
85
- min_diff = abs(data.min() - min_value)
86
- self.assertEqual(data.dtype, torch.float16)
87
- self.assertEqual(data.requires_grad, False)
88
- self.assertLessEqual(max_diff, 0.001)
89
- self.assertLessEqual(min_diff, 0.001)
90
- self.assertEqual(data.shape, torch.Size([2048, 2, 1, 256]))
91
-
92
- def test_gen_bool_tensor(self):
93
- info = {"type": "torch.Tensor", "dtype": "torch.bool", "shape": [1, 1, 160, 256], "Max": 1, "Min": 0,
94
- "requires_grad": False}
95
- low, high = info.get("Min"), info.get("Max")
96
- shape = tuple(info.get("shape"))
97
- data = gen_bool_tensor(low, high, shape)
98
- self.assertEqual(data.shape, torch.Size([1, 1, 160, 256]))
99
- self.assertEqual(data.dtype, torch.bool)
@@ -1,115 +0,0 @@
1
- import os
2
- import glob
3
- import unittest
4
- import logging
5
- from unittest.mock import patch, mock_open, MagicMock
6
- import json
7
- import signal
8
- from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \
9
- prepare_config, main, ParallelUTConfig
10
-
11
-
12
- class TestMultiRunUT(unittest.TestCase):
13
-
14
- def setUp(self):
15
- self.test_json_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json")
16
- self.test_data = {'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}}
17
- self.test_json_content = json.dumps(self.test_data)
18
- self.forward_split_files_content = [
19
- {'key1': 'TRUE', 'key2': 'TRUE'},
20
- {'key3': 'TRUE', 'key4': 'TRUE'}
21
- ]
22
-
23
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen')
24
- def test_split_json_file(self, mock_FileOpen):
25
- mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value
26
- num_splits = 2
27
- split_files, total_items = split_json_file(self.test_json_file, num_splits, False)
28
- self.assertEqual(len(split_files), num_splits)
29
- self.assertEqual(total_items, len(self.test_data.get('data')))
30
-
31
-
32
- @patch('subprocess.Popen')
33
- @patch('os.path.exists', return_value=True)
34
- @patch('builtins.open', new_callable=mock_open)
35
- @patch('json.load', side_effect=lambda f: {'key1': 'TRUE', 'key2': 'TRUE'})
36
- def test_run_parallel_ut(self, mock_json_load, mock_file, mock_exists, mock_popen):
37
- mock_process = MagicMock()
38
- mock_process.poll.side_effect = [None, None, 1]
39
- mock_process.stdout.readline.side_effect = ['[ERROR] Test Error Message\n', '']
40
- mock_popen.return_value = mock_process
41
-
42
- config = ParallelUTConfig(
43
- api_files=['test.json'],
44
- out_path='./',
45
- num_splits=2,
46
- save_error_data_flag=True,
47
- jit_compile_flag=False,
48
- device_id=[0, 1],
49
- result_csv_path='result.csv',
50
- total_items=2,
51
- real_data_path=None
52
- )
53
-
54
- mock_file.side_effect = [
55
- mock_open(read_data=json.dumps(self.forward_split_files_content[0])).return_value,
56
- mock_open(read_data=json.dumps(self.forward_split_files_content[1])).return_value
57
- ]
58
-
59
- run_parallel_ut(config)
60
-
61
- mock_popen.assert_called()
62
- mock_exists.assert_called()
63
-
64
- @patch('os.remove')
65
- @patch('os.path.realpath', side_effect=lambda x: x)
66
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link')
67
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix')
68
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker')
69
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file',
70
- return_value=(['forward_split1.json', 'forward_split2.json'], 2))
71
- def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link,
72
- mock_realpath, mock_remove):
73
- mock_FileChecker_instance = MagicMock()
74
- mock_FileChecker_instance.common_check.return_value = './'
75
- mock_FileChecker.return_value = mock_FileChecker_instance
76
- args = MagicMock()
77
- args.api_info = 'forward.json'
78
- args.out_path = './'
79
- args.num_splits = 2
80
- args.save_error_data = True
81
- args.jit_compile = False
82
- args.device_id = [0, 1]
83
- args.result_csv_path = None
84
- args.real_data_path = None
85
-
86
- config = prepare_config(args)
87
-
88
- self.assertEqual(config.num_splits, 2)
89
- self.assertTrue(config.save_error_data_flag)
90
- self.assertFalse(config.jit_compile_flag)
91
- self.assertEqual(config.device_id, [0, 1])
92
- self.assertEqual(config.total_items, 2)
93
-
94
-
95
- @patch('argparse.ArgumentParser.parse_args')
96
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config')
97
- @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut')
98
- def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args):
99
- main()
100
- mock_parse_args.assert_called()
101
- mock_prepare_config.assert_called()
102
- mock_run_parallel_ut.assert_called()
103
-
104
- def tearDown(self):
105
- current_directory = os.getcwd()
106
- pattern = os.path.join(current_directory, 'accuracy_checking_*')
107
- files = glob.glob(pattern)
108
-
109
- for file in files:
110
- try:
111
- os.remove(file)
112
- logging.info(f"Deleted file: {file}")
113
- except Exception as e:
114
- logging.error(f"Failed to delete file {file}: {e}")
115
-
@@ -1,72 +0,0 @@
1
- # coding=utf-8
2
- import os
3
- import copy
4
- import unittest
5
- import torch
6
- from unittest.mock import patch, DEFAULT
7
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import *
8
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
9
-
10
- base_dir = os.path.dirname(os.path.realpath(__file__))
11
- forward_file = os.path.join(base_dir, "forward.json")
12
- forward_content = get_json_contents(forward_file)
13
- for api_full_name, api_info_dict in forward_content.items():
14
- api_full_name = api_full_name
15
- api_info_dict = api_info_dict
16
-
17
-
18
- class TestRunUtMethods(unittest.TestCase):
19
- def test_exec_api(self):
20
- api_info = copy.deepcopy(api_info_dict)
21
-
22
- [api_type, api_name, _, _] = api_full_name.split(".")
23
- args, kwargs, need_grad = get_api_info(api_info, api_name, None)
24
- cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
25
- out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
26
- self.assertEqual(out[0].dtype, torch.float32)
27
- self.assertTrue(out[0].requires_grad)
28
- self.assertEqual(out[0].shape, torch.Size([2048, 2, 1, 128]))
29
-
30
- def test_generate_device_params(self):
31
- mock_tensor = torch.rand([2, 2560, 24, 24], dtype=torch.float32, requires_grad=True)
32
-
33
- with patch.multiple('torch.Tensor',
34
- to=DEFAULT,
35
- clone=DEFAULT,
36
- detach=DEFAULT,
37
- requires_grad_=DEFAULT,
38
- type_as=DEFAULT,
39
- retain_grad=DEFAULT) as mocks:
40
- mocks['clone'].return_value = mock_tensor
41
- mocks['detach'].return_value = mock_tensor
42
- mocks['requires_grad_'].return_value = mock_tensor
43
- mocks['type_as'].return_value = mock_tensor
44
- mocks['retain_grad'].return_value = None
45
- mocks['to'].return_value = mock_tensor
46
-
47
- device_args, device_kwargs = generate_device_params([mock_tensor], {'inplace': False}, True, '')
48
- self.assertEqual(len(device_args), 1)
49
- self.assertEqual(device_args[0].dtype, torch.float32)
50
- self.assertTrue(device_args[0].requires_grad)
51
- self.assertEqual(device_args[0].shape, torch.Size([2, 2560, 24, 24]))
52
- self.assertEqual(device_kwargs, {'inplace': False})
53
-
54
- def test_generate_cpu_params(self):
55
- api_info = copy.deepcopy(api_info_dict)
56
- [api_type, api_name, _, _] = api_full_name.split(".")
57
- args, kwargs, need_grad = get_api_info(api_info, api_name, None)
58
- cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
59
- self.assertEqual(len(cpu_args), 2)
60
- self.assertEqual(cpu_args[0].dtype, torch.float32)
61
- self.assertTrue(cpu_args[0].requires_grad)
62
- self.assertEqual(cpu_args[0].shape, torch.Size([2048, 2, 1, 256]))
63
- self.assertEqual(cpu_kwargs, {'dim': -1})
64
-
65
- def test_UtDataInfo(self):
66
- data_info = UtDataInfo(None, None, None, None, None, None, None)
67
- self.assertIsNone(data_info.bench_grad)
68
- self.assertIsNone(data_info.device_grad)
69
- self.assertIsNone(data_info.device_output)
70
- self.assertIsNone(data_info.bench_output)
71
- self.assertIsNone(data_info.grad_in)
72
- self.assertIsNone(data_info.in_fwd_data_list)
@@ -1,17 +0,0 @@
1
- # coding=utf-8
2
- import unittest
3
- from msprobe.pytorch.compare.acc_compare import rename_api
4
-
5
- class TestUtilsMethods(unittest.TestCase):
6
-
7
- def test_rename_api(self):
8
- test_name_1 = "Distributed.broadcast.0.forward.input.0"
9
- expect_name_1 = "Distributed.broadcast.input.0"
10
- actual_name_1 = rename_api(test_name_1, "forward")
11
- self.assertEqual(actual_name_1, expect_name_1)
12
-
13
- test_name_2 = "Torch.sum.0.backward.output.0"
14
- expect_name_2 = "Torch.sum.output.0"
15
- actual_name_2 = rename_api(test_name_2, "backward")
16
- self.assertEqual(actual_name_2, expect_name_2)
17
-
@@ -1,105 +0,0 @@
1
- from unittest import TestCase
2
-
3
- import torch
4
- from msprobe.core.common.const import Const
5
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
6
- from msprobe.pytorch.free_benchmark.common.params import data_pre_deal
7
- from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
8
-
9
-
10
- class TestPerturbedLayer(TestCase):
11
-
12
- # 对输出精度和输入精度一致算子使用升精度扰动因子时, 输出结果的精度也会提升
13
- def test_improve_precision_layer_handle_with_out_dtype_changing(self):
14
- api_name = "Torch.mul.0.forward"
15
- x = torch.randn(2, 3, dtype=torch.float16)
16
- y = torch.randn(2, 3, dtype=torch.float16)
17
- out = torch.mul(x, y)
18
-
19
- data_params = data_pre_deal(api_name, torch.mul, (x, y), {})
20
- data_params.fuzz_stage = Const.FORWARD
21
- data_params.original_result = out
22
-
23
- layer = LayerFactory.create(
24
- api_name, DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION
25
- )
26
- layer.handle(data_params)
27
- self.assertEqual(data_params.original_result.dtype, torch.float16)
28
- self.assertEqual(layer.perturbed_value, torch.float32)
29
- self.assertEqual(data_params.perturbed_result.dtype, torch.float32)
30
-
31
- # 对于可迭代类型的输入, 升精度方法会遍历其中元素对支持类型输入升精度
32
- def test_improve_precision_layer_with_iterable_inputs(self):
33
- api_name = "iterable.0.forward"
34
- tensor_a = torch.randn(2, 3, dtype=torch.bfloat16)
35
- tensor_b = torch.randn(2, 3, dtype=torch.float16)
36
- tensor_c = torch.randn(2, 3, dtype=torch.float32)
37
- tensor_d = torch.randn(2, 3, dtype=torch.float64)
38
- tensor_f = torch.randn(2, 3, dtype=torch.float64).to(torch.int32)
39
- inputs = [tensor_a, tensor_b, {"c": tensor_c, "d": tensor_d}, tensor_f]
40
-
41
- layer = LayerFactory.create(
42
- api_name, DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION
43
- )
44
- Perturbed_value = layer.improve_tensor_precision(inputs)
45
- self.assertEqual(Perturbed_value[0].dtype, torch.float32)
46
- self.assertEqual(Perturbed_value[1].dtype, torch.float32)
47
- self.assertEqual(Perturbed_value[2]["c"].dtype, torch.float32)
48
- self.assertEqual(Perturbed_value[2]["d"].dtype, torch.float64)
49
- self.assertEqual(Perturbed_value[3].dtype, torch.int32)
50
-
51
- # no_change扰动因子不会改变输入
52
- def test_no_change_layer(self):
53
- api_name = "nochange.0.forward"
54
- inputs = torch.as_tensor([1e-9, 1e-2], dtype=torch.float32)
55
- layer = LayerFactory.create(
56
- api_name, DeviceType.NPU, PerturbationMode.NO_CHANGE
57
- )
58
- Perturbed_value = layer.no_change(inputs)
59
- self.assertEqual(Perturbed_value[0], 1e-9)
60
- self.assertEqual(Perturbed_value[1], 1e-2)
61
-
62
- # 对于一维二维张量,change_value扰动因子会交换首尾值的位置
63
- def test_change_value_layer(self):
64
- api_name = "change.0.forward"
65
- inputs_1dim = torch.as_tensor([1e-9, 1e-7, 1e-2], dtype=torch.float32)
66
- inputs_2dim = torch.as_tensor(
67
- [[1e-9, 1e-7, 1e-2], [1e-9, 1e-2, 1e-7]], dtype=torch.float32
68
- )
69
- layer = LayerFactory.create(
70
- api_name, DeviceType.NPU, PerturbationMode.CHANGE_VALUE
71
- )
72
- Perturbed_value_1dim = layer.change_value(inputs_1dim)
73
- layer.is_added = False
74
- Perturbed_value_2dim = layer.change_value(inputs_2dim)
75
- self.assertEqual(Perturbed_value_1dim[0], 1e-2)
76
- self.assertEqual(Perturbed_value_1dim[2], 1e-9)
77
- self.assertEqual(Perturbed_value_2dim[0][0], 1e-7)
78
- self.assertEqual(Perturbed_value_2dim[-1][-1], 1e-9)
79
-
80
- # 对于输入张量,bit_noise扰动因子对大于极小值的部分进行末尾比特翻转
81
- def test_bit_noise_layer(self):
82
- api_name = "bitnoise.0.forward"
83
- inputs = torch.as_tensor(
84
- [4096.00048828125, 16777216, 1e-38], dtype=torch.float32
85
- )
86
- layer = LayerFactory.create(
87
- api_name, DeviceType.NPU, PerturbationMode.BIT_NOISE
88
- )
89
- Perturbed_value = layer.add_bit_noise(inputs)
90
- self.assertEqual(Perturbed_value[0], 4096.0000000000)
91
- self.assertEqual(Perturbed_value[1], 16777218)
92
- self.assertEqual(Perturbed_value[2], 1e-38)
93
-
94
- # 对于输入张量,add_noise扰动因子对大于极小值的部分增加一个小值
95
- def test_add_noise_layer(self):
96
- api_name = "addnoise.0.forward"
97
- inputs = torch.as_tensor(
98
- [1e-1, 1e-2], dtype=torch.bfloat16
99
- )
100
- layer = LayerFactory.create(
101
- api_name, DeviceType.NPU, PerturbationMode.ADD_NOISE
102
- )
103
- Perturbed_value = layer.add_noise(inputs)
104
- self.assertEqual(Perturbed_value[0], 1e-1+1e-4)
105
- self.assertEqual(Perturbed_value[1], 1e-2)
@@ -1,121 +0,0 @@
1
- from abc import ABC
2
- from unittest import TestCase
3
-
4
- import torch
5
- from msprobe.core.common.const import Const
6
- from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
7
- from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
8
- from msprobe.pytorch.free_benchmark.common.enums import (
9
- DeviceType,
10
- FuzzLevel,
11
- HandlerType,
12
- PerturbationMode,
13
- )
14
- from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handler_params
15
- from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
16
- FuzzHandlerFactory,
17
- )
18
-
19
-
20
- class Config(ABC):
21
- """
22
- 用以提供参数配置
23
- """
24
- def __init__(self, handler_type, preheat_config):
25
- self.fuzz_stage = Const.FORWARD
26
- self.handler_type = handler_type
27
- self.fuzz_device = DeviceType.NPU
28
- self.fuzz_level = FuzzLevel.BASE_LEVEL
29
- self.pert_mode = PerturbationMode.IMPROVE_PRECISION
30
- self.preheat_config = preheat_config
31
-
32
-
33
- class TestFuzzHandler(TestCase):
34
-
35
- def setUp(self) -> None:
36
- origin_inputs = [
37
- torch.as_tensor([3.01, 3.02], dtype=torch.float16),
38
- torch.as_tensor([0.02, 0.02], dtype=torch.float16),
39
- ]
40
- # 将输入乘以一个大于误差阈值1.002的值,模拟二次执行出现误差
41
- perturbed_inputs = [
42
- (value * 1.0021).to(torch.float32).to("cpu") for value in origin_inputs
43
- ]
44
- origin_output = torch.add(*origin_inputs)
45
- perturbed_output = torch.add(*perturbed_inputs)
46
- # 实例有问题的data对象
47
- self.data_params = DataParams(
48
- args=origin_inputs,
49
- kwargs={},
50
- original_result=origin_output,
51
- perturbed_result=perturbed_output,
52
- origin_func=torch.add,
53
- )
54
- self.api_name = "add.0.forward"
55
- self.step = 0
56
-
57
- def test_result_handler_check(self):
58
- # 对于check处理类,扰动前后输出不一致的情况会有UnequalRow对象生成
59
- for _ in range(2):
60
- config = Config(
61
- HandlerType.CHECK, {PreheatConfig.IF_PREHEAT: False}
62
- )
63
- handler_params = make_handler_params(self.api_name, config, self.step)
64
- handler = FuzzHandlerFactory.create(handler_params)
65
- handler.handle(self.data_params)
66
- self.assertEqual(
67
- len(handler.get_unequal_rows()), 1
68
- )
69
-
70
- def test_result_handler_fix(self):
71
- # 对于fix处理类,扰动后输出会替代原始输出, dtype和原始输出一致,但值为新输出值
72
- config = Config(
73
- HandlerType.FIX, {PreheatConfig.IF_PREHEAT: False}
74
- )
75
- handler_params = make_handler_params(self.api_name, config, self.step)
76
- handler = FuzzHandlerFactory.create(handler_params)
77
- result = handler.handle(self.data_params)
78
- self.assertEqual(result.dtype, torch.float16)
79
- self.assertEqual(result.device, self.data_params.original_result.device)
80
- self.assertAlmostEqual(
81
- result[0], self.data_params.perturbed_result.to(torch.float16)[0]
82
- )
83
- self.assertAlmostEqual(
84
- result[1], self.data_params.perturbed_result.to(torch.float16)[1]
85
- )
86
-
87
- def test_result_handler_preheat(self):
88
- # 对于preheat处理类,在预热阶段后的阈值会根据CPU调整
89
- config = Config(
90
- HandlerType.CHECK,
91
- {
92
- PreheatConfig.IF_PREHEAT: True,
93
- PreheatConfig.PREHEAT_STEP: 4,
94
- PreheatConfig.MAX_SAMPLE: 3
95
- }
96
- )
97
- for _ in range(3):
98
- handler_params = make_handler_params(self.api_name, config, 0)
99
- handler = FuzzHandlerFactory.create(handler_params)
100
- handler.handle(self.data_params)
101
- # 通过preheat_counter的数据可以判断预热是否正常执行,这里第一个step会记录api执行次数
102
- self.assertEqual(preheat_counter.get_one_step_used_api("add"), 3)
103
- for step in range(1, 4):
104
- for _ in range(3):
105
- handler_params = make_handler_params(self.api_name, config, step)
106
- handler = FuzzHandlerFactory.create(handler_params)
107
- handler.handle(self.data_params)
108
- # call time记录当前step api的调用次数
109
- self.assertEqual(preheat_counter.get_api_called_time("add"), 3)
110
- # 对于3个step最多采样三次的预热设置,sample time应该每次采样一例
111
- self.assertEqual(preheat_counter.get_api_sample_time("add"), 1)
112
- # 预热阶段,api阈值应该在两个阈值超参之间
113
- api_threshld = preheat_counter.get_api_thd("add", "torch.float16")
114
- self.assertLessEqual(
115
- api_threshld,
116
- ThresholdConfig.PREHEAT_INITIAL_THD
117
- )
118
- self.assertGreaterEqual(
119
- api_threshld,
120
- ThresholdConfig.DTYPE_PER_THD[torch.float16]
121
- )
@@ -1,101 +0,0 @@
1
- import functools
2
- from abc import ABC
3
- from unittest import TestCase
4
-
5
- import torch
6
- import torch.nn as nn
7
- from msprobe.core.common.const import Const
8
- from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck
9
- from msprobe.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
10
- from msprobe.pytorch.free_benchmark.common.enums import (
11
- DeviceType,
12
- FuzzLevel,
13
- HandlerType,
14
- PerturbationMode,
15
- )
16
-
17
-
18
- class Config(ABC):
19
- """
20
- 用以提供参数配置
21
- """
22
-
23
- def __init__(self, fuzz_stage, handler_type):
24
- self.fuzz_stage = fuzz_stage
25
- self.handler_type = handler_type
26
- self.fuzz_device = DeviceType.NPU
27
- self.fuzz_level = FuzzLevel.BASE_LEVEL
28
- self.pert_mode = PerturbationMode.IMPROVE_PRECISION
29
- self.preheat_config = {PreheatConfig.IF_PREHEAT: False}
30
-
31
-
32
- class WrapMul(nn.Module):
33
- """
34
- 用nn.module包装mul算子, 在forward中调用torch.mul
35
- """
36
-
37
- def __init__(self, op_name) -> None:
38
- super().__init__()
39
- self.op_name = op_name
40
-
41
- def forward(self, *args, **kwargs):
42
- return torch.mul(*args, **kwargs)
43
-
44
-
45
- class UnequalDataProcessor(ABC):
46
- """
47
- 接口类, 处理检测不一致结果
48
- """
49
-
50
- def __init__(self) -> None:
51
- super().__init__()
52
- self.unequal_rows = []
53
-
54
- def update_unequal_rows(self, unequal_rows):
55
- self.unequal_rows.append(unequal_rows)
56
-
57
-
58
- class TestInterface(TestCase):
59
- def setUp(self):
60
- self.api_name = "Torch.mul.0"
61
-
62
- def testForwardFix(self):
63
- # 对于前向接口,在forward钩子中开启FIX,返回结果给hook的输出
64
- config = Config(Const.FORWARD, HandlerType.FIX)
65
- checker = FreeBenchmarkCheck(config)
66
- # 执行算子前向
67
- x = torch.randn(2, 3).to(torch.float16)
68
- y = torch.randn(2, 3).to(torch.float16)
69
- mul_module = WrapMul(self.api_name)
70
- out = mul_module(x, y)
71
- # 模拟forward hook中调用无标杆前向检测接口
72
- result, _ = checker.forward(
73
- self.api_name,
74
- mul_module,
75
- args=(x, y),
76
- kwargs={},
77
- output=out,
78
- )
79
- self.assertEqual(result.dtype, torch.float32)
80
-
81
- def testBackwardCheck(self):
82
- # 对于反向接口,在pre forward时暂存input, 然后在backwrad后进行对比
83
- config = Config(Const.BACKWARD, HandlerType.CHECK)
84
- checker = FreeBenchmarkCheck(config)
85
- processor = UnequalDataProcessor()
86
- # 初始化输入输出
87
- x = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
88
- y = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
89
- grad_output = torch.tensor([1,1], dtype=torch.float16)
90
- backward_name = Const.SEP.join([self.api_name, Const.BACKWARD])
91
- # 执行前向生成grad saver实例
92
- mul_module = WrapMul(self.api_name)
93
- checker.pre_forward(backward_name, mul_module, processor, (x, y), {})
94
- # 执行算子前向和反向, 并反向获取扰动后grad_input
95
- out = mul_module(x, y)
96
- checker.backward(backward_name, mul_module, grad_output)
97
- out.backward(torch.ones_like(out))
98
- # module是否添加暂存器, 其中反向钩子执行扰动后grad_input是否正确
99
- self.assertTrue(hasattr(mul_module, CommonField.GRADSAVER))
100
- grad_saver = getattr(mul_module, CommonField.GRADSAVER)
101
- self.assertEqual(grad_saver.perturbed_grad_input[0][0], 2)
@@ -1,15 +0,0 @@
1
- import unittest
2
-
3
- import torch.nn as nn
4
- from msprobe.pytorch import PrecisionDebugger
5
- from msprobe.pytorch.functional.dump_module import module_dump, module_count
6
-
7
-
8
- class TestDumpModule(unittest.TestCase):
9
- def setUp(self):
10
- self.module = nn.Linear(in_features=8, out_features=4)
11
-
12
- def test_module_dump(self):
13
- PrecisionDebugger(dump_path="./dump")
14
- module_dump(self.module, "TestModule")
15
- self.assertTrue("TestModule" in module_count)