mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,33 +1,49 @@
1
- import os
2
- from msprobe.core.common.utils import CompareException, load_yaml
3
-
4
-
5
- class AtenIrMapping():
6
- def __init__(self):
7
- cur_path = os.path.dirname(os.path.realpath(__file__))
8
- yaml_path = os.path.join(cur_path, "mapping.yaml")
9
- self.aten_mapping = load_yaml(yaml_path)
10
-
11
- def match(self, op1, op2):
12
- if "Aten" in op1 and "Aten" not in op2:
13
- return self.match_op(op1, op2)
14
- else:
15
- return self.match_op(op2, op1)
16
-
17
- def match_op(self, aten_op, torch_op):
18
- try:
19
- aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
20
- aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
21
- torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
22
- except IndexError as e:
23
- err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
24
- raise CompareException.INVALID_DATA_ERROR(err_msg) from e
25
- matching_op = self.aten_mapping.get(aten_op_raw_name)
26
- if matching_op is None:
27
- return False
28
- if matching_op.lower() == torch_op_raw_name:
29
- return True
30
- return False
31
-
32
-
33
- graph_mapping = AtenIrMapping()
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from msprobe.core.common.utils import CompareException
18
+ from msprobe.core.common.file_utils import load_yaml
19
+
20
+
21
+ class AtenIrMapping():
22
+ def __init__(self):
23
+ cur_path = os.path.dirname(os.path.realpath(__file__))
24
+ yaml_path = os.path.join(cur_path, "mapping.yaml")
25
+ self.aten_mapping = load_yaml(yaml_path)
26
+
27
+ def match(self, op1, op2):
28
+ if "Aten" in op1 and "Aten" not in op2:
29
+ return self.match_op(op1, op2)
30
+ else:
31
+ return self.match_op(op2, op1)
32
+
33
+ def match_op(self, aten_op, torch_op):
34
+ try:
35
+ aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
36
+ aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
37
+ torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
38
+ except IndexError as e:
39
+ err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
40
+ raise CompareException.INVALID_DATA_ERROR(err_msg) from e
41
+ matching_op = self.aten_mapping.get(aten_op_raw_name)
42
+ if matching_op is None:
43
+ return False
44
+ if matching_op.lower() == torch_op_raw_name:
45
+ return True
46
+ return False
47
+
48
+
49
+ graph_mapping = AtenIrMapping()
@@ -1,40 +1,82 @@
1
- import os.path
2
- import torch
3
- from msprobe.core.common.const import FileCheckConst, Const
4
- from msprobe.core.common.log import logger
5
- from msprobe.core.common.exceptions import FileCheckException
6
- from msprobe.core.compare.acc_compare import Comparator
7
- from msprobe.core.common.utils import create_directory, check_configuration_param, task_dumppath_get, \
8
- check_compare_param, FileChecker
9
- from msprobe.core.common.utils import CompareException
10
-
11
-
12
- class PTComparator (Comparator):
13
- def __init__(self):
14
- self.frame_name = PTComparator.__name__
15
-
16
- def read_npy_data(self, dir_path, file_name):
17
- data_path = os.path.join(dir_path, file_name)
18
- path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
19
- FileCheckConst.PT_SUFFIX, False)
20
- data_path = path_checker.common_check()
21
- data_value = torch.load(data_path, map_location=torch.device('cpu')).detach() # detach for less memory
22
- if data_value.dtype == torch.bfloat16:
23
- data_value = data_value.to(torch.float32)
24
- data_value = data_value.numpy()
25
- return data_value
26
-
27
-
28
- def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
29
- try:
30
- summary_compare, md5_compare = task_dumppath_get(input_param)
31
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
32
- create_directory(output_path)
33
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
34
- except (CompareException, FileCheckException) as error:
35
- logger.error('Compare failed. Please check the arguments and do it again!')
36
- raise CompareException(error.code) from error
37
- pt_comparator = PTComparator()
38
- pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
39
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
40
- md5_compare=md5_compare)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os.path
17
+ import torch
18
+ from msprobe.core.common.const import FileCheckConst
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.core.common.exceptions import FileCheckException
21
+ from msprobe.core.compare.acc_compare import Comparator
22
+ from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
23
+ CompareException
24
+ from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
25
+ from msprobe.pytorch.common.utils import load_pt
26
+
27
+
28
+ class PTComparator (Comparator):
29
+ def __init__(self, data_mapping=None):
30
+ self.frame_name = PTComparator.__name__
31
+ self.data_mapping = data_mapping
32
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
33
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
34
+ elif isinstance(self.data_mapping, dict):
35
+ self.data_mapping_dict = self.data_mapping
36
+ else:
37
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
38
+ f"{type(self.data_mapping)}")
39
+
40
+ def load_mapping_file(self, mapping_file):
41
+ if isinstance(mapping_file, str):
42
+ mapping_dict = load_yaml(mapping_file)
43
+ else:
44
+ mapping_dict = {}
45
+ return mapping_dict
46
+
47
+ def read_npy_data(self, dir_path, file_name):
48
+ data_path = os.path.join(dir_path, file_name)
49
+ path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
50
+ FileCheckConst.PT_SUFFIX, False)
51
+ data_path = path_checker.common_check()
52
+ try:
53
+ data_value = load_pt(data_path,
54
+ to_cpu=True).detach() # detach because numpy can not process gradient information
55
+ except RuntimeError as e:
56
+ # 这里捕获 load_pt 中抛出的异常
57
+ logger.error(f"Failed to load the .pt file at {data_path}.")
58
+ raise CompareException(CompareException.INVALID_FILE_ERROR) from e
59
+ except AttributeError as e:
60
+ # 这里捕获 detach 方法抛出的异常
61
+ logger.error(f"Failed to detach the loaded tensor.")
62
+ raise CompareException(CompareException.DETACH_ERROR) from e
63
+ if data_value.dtype == torch.bfloat16:
64
+ data_value = data_value.to(torch.float32)
65
+ data_value = data_value.numpy()
66
+ return data_value
67
+
68
+
69
+ def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
70
+ try:
71
+ summary_compare, md5_compare = task_dumppath_get(input_param)
72
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
73
+ create_directory(output_path)
74
+ check_compare_param(input_param, output_path, summary_compare, md5_compare)
75
+ data_mapping = kwargs.get('data_mapping', None)
76
+ except (CompareException, FileCheckException) as error:
77
+ logger.error('Compare failed. Please check the arguments and do it again!')
78
+ raise CompareException(error.code) from error
79
+ pt_comparator = PTComparator(data_mapping)
80
+ pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
81
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
82
+ md5_compare=md5_compare)
@@ -1,95 +1,108 @@
1
- from msprobe.pytorch.common import seed_all
2
- from msprobe.pytorch.common.log import logger
3
- from msprobe.core.common.const import Const
4
-
5
-
6
class DebuggerConfig:
    """Debugger settings for the PyTorch framework (1.0.3 layout).

    Explicit ``task``/``dump_path``/``level`` arguments take priority over the
    values on ``common_config``; task-specific options come from
    ``task_config``. Construction validates the result and seeds the run.
    """

    def __init__(self, common_config, task_config, task, dump_path, level):
        """Resolve every setting, validate it, and seed the random generators.

        Raises:
            Exception: if the task, level or dump path is invalid.
            ValueError: if a rank/step entry is not a non-negative integer, or
                if level is "L2" and ``scope``/``backward_input`` are not
                configured as required.
        """
        self.dump_path = dump_path if dump_path else common_config.dump_path
        self.task = task or common_config.task or Const.STATISTICS
        self.rank = common_config.rank if common_config.rank else []
        self.step = common_config.step if common_config.step else []
        self.level = level or common_config.level or "L1"
        self.seed = common_config.seed if common_config.seed else 1234
        self.is_deterministic = common_config.is_deterministic
        self.enable_dataloader = common_config.enable_dataloader
        self.scope = task_config.scope if task_config.scope else []
        self.list = task_config.list if task_config.list else []
        self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
        self.backward_input_list = task_config.backward_input if task_config.backward_input else []
        self.backward_input = {}
        self.acl_config = common_config.acl_config if common_config.acl_config else ""
        self.is_forward_acl_dump = True
        self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
        self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
        self.framework = Const.PT_FRAMEWORK

        if self.task == Const.FREE_BENCHMARK:
            self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
            self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
            self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
            self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
            self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
            self.preheat_config = {
                # if_preheat may legitimately be False, so only None falls back to True
                "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
                "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
                "max_sample": task_config.max_sample if task_config.max_sample else 20,
            }

        self.online_run_ut = False
        if self.task == Const.TENSOR:
            # dump api tensor and collaborate with online run_ut
            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
            self.tls_path = task_config.tls_path if task_config.tls_path else ""
            self.host = task_config.host if task_config.host else ""
            self.port = task_config.port if task_config.port else -1

        self.check()
        if self.step:
            # sorted() builds a new list, so the list owned by common_config
            # is not reordered behind the caller's back.
            self.step = sorted(self.step)
        if self.level == "L2":
            self._prepare_l2_scope()
        seed_all(self.seed, self.is_deterministic)

    def check_kwargs(self):
        """Raise Exception when the task, level or dump path is invalid."""
        if self.task and self.task not in Const.TASK_LIST:
            raise Exception("task is invalid")
        if self.level and self.level not in Const.LEVEL_LIST:
            raise Exception("level is invalid")
        if not self.dump_path:
            raise Exception("Invalid dump path, please check your config")

    def check(self):
        """Run all validations; return True when none of them raises."""
        self.check_kwargs()
        self._check_rank()
        self._check_step()
        return True

    def check_model(self, model):
        """Raise Exception when level "L0" or "mix" is used without a model."""
        if self.level in ["L0", "mix"] and not model:
            raise Exception(
                f"For level {self.level}, PrecisionDebugger must receive a model argument."
            )

    def _prepare_l2_scope(self):
        """Validate the L2 scope and rewrite a backward scope to its forward name."""
        scope_is_valid = (
            bool(self.scope)
            and isinstance(self.scope, list)
            and len(self.scope) == 1
            # A non-string entry used to fail later with TypeError/AttributeError.
            and isinstance(self.scope[0], str)
        )
        if not scope_is_valid:
            raise ValueError("scope must be configured as a list with one api name")
        api_name = self.scope[0]
        if Const.BACKWARD not in api_name:
            return
        if not self.backward_input_list:
            raise ValueError("backward_input must be configured when scope contains 'backward'")
        self.is_forward_acl_dump = False
        forward_name = api_name.replace(Const.BACKWARD, Const.FORWARD)
        # Rebind to a fresh list so task_config.scope is left untouched.
        self.scope = [forward_name]
        self.backward_input[forward_name] = self.backward_input_list[0]

    def _check_rank(self):
        """Require every configured rank to be a non-negative integer."""
        if self.rank:
            for rank_id in self.rank:
                if not isinstance(rank_id, int) or rank_id < 0:
                    raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
            else:
                # Reached only when the loop finished without raising.
                logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")

    def _check_step(self):
        """Require every configured step to be a non-negative integer."""
        if self.step:
            for s in self.step:
                if not isinstance(s, int) or s < 0:
                    raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import MsprobeException
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+
23
class DebuggerConfig:
    """Debugger settings for the PyTorch framework, merged from several sources.

    Explicit ``task``/``dump_path``/``level`` arguments take priority over the
    values on ``common_config``; task-specific options are read from
    ``task_config``. Construction validates the merged result.
    """

    def __init__(self, common_config, task_config, task, dump_path, level):
        """Resolve every setting, validate it, and prepare the L2 scope.

        Raises:
            MsprobeException: if the task, level or dump path is invalid.
            ValueError: if level is "L2" and ``scope``/``backward_input`` are
                not configured as required.
        """
        self.dump_path = dump_path if dump_path else common_config.dump_path
        self.task = task or common_config.task or Const.STATISTICS
        self.rank = common_config.rank if common_config.rank else []
        self.step = common_config.step if common_config.step else []
        self.level = level or common_config.level or "L1"
        self.enable_dataloader = common_config.enable_dataloader
        self.scope = task_config.scope if task_config.scope else []
        self.list = task_config.list if task_config.list else []
        self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
        self.backward_input_list = task_config.backward_input if task_config.backward_input else []
        self.backward_input = {}
        self.acl_config = common_config.acl_config if common_config.acl_config else ""
        self.is_forward_acl_dump = True
        self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
        self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
        self.framework = Const.PT_FRAMEWORK

        if self.task == Const.FREE_BENCHMARK:
            # Values are copied as-is; no fallback defaults are applied here.
            self.fuzz_device = task_config.fuzz_device
            self.handler_type = task_config.handler_type
            self.pert_mode = task_config.pert_mode
            self.fuzz_level = task_config.fuzz_level
            self.fuzz_stage = task_config.fuzz_stage
            self.preheat_config = {
                "if_preheat": task_config.if_preheat,
                "preheat_step": task_config.preheat_step,
                "max_sample": task_config.max_sample
            }

        self.online_run_ut = False
        if self.task == Const.TENSOR:
            # dump api tensor and collaborate with online run_ut
            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
            self.tls_path = task_config.tls_path if task_config.tls_path else ""
            self.host = task_config.host if task_config.host else ""
            self.port = task_config.port if task_config.port else -1

        self.check()

        if self.level == "L2":
            self._prepare_l2_scope()

    def check_kwargs(self):
        """Raise MsprobeException when the task, level or dump path is invalid."""
        if self.task and self.task not in Const.TASK_LIST:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
        if self.level and self.level not in Const.LEVEL_LIST:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
        if not self.dump_path:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   "The dump_path not found.")

    def check(self):
        """Run all validations; return True when none of them raises."""
        self.check_kwargs()
        return True

    def check_model(self, instance, start_model):
        """Decide which model the debugger uses for the "L0" and "mix" levels.

        Args:
            instance: Object exposing a ``model`` attribute; it receives
                ``start_model`` when that is a valid module.
            start_model: Model passed to the start interface, or None.

        Raises:
            MsprobeException: at level "L0"/"mix" when no model is available
                from either source, or when ``start_model`` is not a
                ``torch.nn.Module``.
        """
        if self.level not in ["L0", "mix"]:
            if instance.model is not None or start_model is not None:
                logger.warning_on_rank_0(
                    "The current level is not L0 or mix level, so the model parameters will not be used.")
            return
        if start_model is None:
            if instance.model is None:
                logger.error_on_rank_0(
                    f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "missing the parameter 'model'")
            # Keep the model that is already set on the instance.
            return
        if isinstance(start_model, torch.nn.Module):
            instance.model = start_model
        else:
            logger.error_on_rank_0("The 'model' parameter of start must be a torch.nn.Module type.")
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, "model must be a torch.nn.Module")

    def _prepare_l2_scope(self):
        """Validate the L2 scope and rewrite a backward scope to its forward name."""
        scope_is_valid = (
            bool(self.scope)
            and isinstance(self.scope, list)
            and len(self.scope) == 1
            # A non-string entry used to fail later with TypeError/AttributeError.
            and isinstance(self.scope[0], str)
        )
        if not scope_is_valid:
            raise ValueError("scope must be configured as a list with one api name")
        api_name = self.scope[0]
        if Const.BACKWARD not in api_name:
            return
        if not self.backward_input_list:
            raise ValueError("backward_input must be configured when scope contains 'backward'")
        self.is_forward_acl_dump = False
        forward_name = api_name.replace(Const.BACKWARD, Const.FORWARD)
        # Rebind to a fresh list so task_config.scope is left untouched.
        self.scope = [forward_name]
        self.backward_input[forward_name] = self.backward_input_list[0]