mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,293 +1,314 @@
1
- import os
2
- import inspect
3
- from dataclasses import dataclass
4
- from typing import Tuple, Dict, Optional, Any
5
- import numpy as np
6
- from msprobe.core.common.log import logger
7
- from msprobe.core.common.utils import convert_tuple
8
- from msprobe.core.common.const import Const
9
-
10
-
11
- @dataclass
12
- class ModuleForwardInputsOutputs:
13
- args: Optional[Tuple]
14
- kwargs: Optional[Dict]
15
- output: Any
16
-
17
- @property
18
- def args_tuple(self):
19
- return convert_tuple(self.args)
20
-
21
- @property
22
- def output_tuple(self):
23
- return convert_tuple(self.output)
24
-
25
- def concat_args_and_kwargs(self):
26
- args = self.args + tuple(self.kwargs.values())
27
- return args
28
-
29
-
30
- @dataclass
31
- class ModuleBackwardInputsOutputs:
32
- grad_output: Optional[Tuple]
33
- grad_input: Optional[Tuple]
34
-
35
- @property
36
- def grad_input_tuple(self):
37
- return convert_tuple(self.grad_input)
38
-
39
- @property
40
- def grad_output_tuple(self):
41
- return convert_tuple(self.grad_output)
42
-
43
-
44
- @dataclass
45
- class ModuleBackwardInputs:
46
- grad_input: Optional[Tuple]
47
-
48
- @property
49
- def grad_input_tuple(self):
50
- return convert_tuple(self.grad_input)
51
-
52
-
53
- @dataclass
54
- class ModuleBackwardOutputs:
55
- grad_output: Optional[Tuple]
56
-
57
- @property
58
- def grad_output_tuple(self):
59
- return convert_tuple(self.grad_output)
60
-
61
-
62
- class TensorStatInfo:
63
- def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
64
- self.max = max_val
65
- self.min = min_val
66
- self.mean = mean_val
67
- self.norm = norm_val
68
-
69
-
70
- class BaseDataProcessor:
71
- _recursive_key_stack = []
72
- special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
73
- bool, int, float, str, slice)
74
-
75
- def __init__(self, config, data_writer):
76
- self.data_writer = data_writer
77
- self.config = config
78
- self.api_info_struct = {}
79
- self.stack_info_struct = {}
80
- self.current_api_or_module_name = None
81
- self.api_data_category = None
82
- self.has_overflow = False
83
- self.current_iter = 0
84
- self._return_forward_new_output = False
85
- self._forward_new_output = None
86
-
87
- @property
88
- def data_path(self):
89
- return self.data_writer.dump_tensor_data_dir
90
-
91
- @property
92
- def is_terminated(self):
93
- return False
94
-
95
- @staticmethod
96
- def analyze_api_call_stack(name):
97
- stack_str = []
98
- for (_, path, line, func, code, _) in inspect.stack()[5:]:
99
- if not code:
100
- continue
101
- stack_line = " ".join([
102
- "File", ", ".join([
103
- path,
104
- " ".join(["line", str(line)]),
105
- " ".join(["in", func]),
106
- " ".join(["\n", code[0].strip()])
107
- ])
108
- ])
109
- stack_str.append(stack_line)
110
- stack_info_struct = {name: stack_str}
111
- return stack_info_struct
112
-
113
- @staticmethod
114
- def transfer_type(data):
115
- dtype = str(type(data))
116
- if 'int' in dtype:
117
- return int(data)
118
- elif 'float' in dtype:
119
- return float(data)
120
- else:
121
- return data
122
-
123
- @staticmethod
124
- def _convert_numpy_to_builtin(arg):
125
- type_mapping = {
126
- np.integer: int,
127
- np.floating: float,
128
- np.bool_: bool,
129
- np.complexfloating: complex,
130
- np.str_: str,
131
- np.byte: bytes,
132
- np.unicode_: str
133
- }
134
- for numpy_type, builtin_type in type_mapping.items():
135
- if isinstance(arg, numpy_type):
136
- return builtin_type(arg), type(arg).__name__
137
- return arg, ''
138
-
139
- @staticmethod
140
- def _analyze_numpy(value, numpy_type):
141
- return {"type": numpy_type, "value": value}
142
-
143
- @classmethod
144
- def get_special_types(cls):
145
- return cls.special_type
146
-
147
- @classmethod
148
- def recursive_apply_transform(cls, args, transform):
149
- if isinstance(args, cls.get_special_types()):
150
- arg_transform = transform(args, cls._recursive_key_stack)
151
- return arg_transform
152
- elif isinstance(args, (list, tuple)):
153
- result_list = []
154
- for i, arg in enumerate(args):
155
- cls._recursive_key_stack.append(str(i))
156
- result_list.append(cls.recursive_apply_transform(arg, transform))
157
- cls._recursive_key_stack.pop()
158
- return type(args)(result_list)
159
- elif isinstance(args, dict):
160
- resutl_dict = {}
161
- for k, arg in args.items():
162
- cls._recursive_key_stack.append(str(k))
163
- resutl_dict[k] = cls.recursive_apply_transform(arg, transform)
164
- cls._recursive_key_stack.pop()
165
- return resutl_dict
166
- elif args is not None:
167
- logger.warning(f"Data type {type(args)} is not supported.")
168
- return None
169
- else:
170
- return None
171
-
172
- def if_return_forward_new_output(self):
173
- return self._return_forward_new_output
174
-
175
- def get_forward_new_output(self):
176
- self._return_forward_new_output = False
177
- return self._forward_new_output
178
-
179
- def update_iter(self, current_iter):
180
- self.current_iter = current_iter
181
-
182
- def visit_and_clear_overflow_status(self, api_or_module_name):
183
- if self.current_api_or_module_name != api_or_module_name:
184
- self.current_api_or_module_name = api_or_module_name
185
- self.has_overflow = False
186
-
187
- def is_dump_for_data_mode(self, forward_backward, input_output):
188
- """
189
- Compare the parameters with data_mode to determine whether to dump.
190
-
191
- Args:
192
- forward_backward(str): The forward or backward mode to check.
193
- input_output(str): The input or output mode to check.
194
-
195
- Return:
196
- bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
197
- """
198
- return (Const.ALL in self.config.data_mode or
199
- forward_backward in self.config.data_mode or
200
- input_output in self.config.data_mode)
201
-
202
- def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
203
- pass
204
-
205
- def analyze_element(self, element):
206
- return self.recursive_apply_transform(element, self.analyze_single_element)
207
-
208
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
209
- api_info_struct = {}
210
- # check whether data_mode contains forward or input
211
- if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
212
- api_info_struct[name] = {}
213
- self.api_data_category = Const.INPUT
214
- args_info_list = self.analyze_element(module_input_output.args_tuple)
215
- api_info_struct[name][Const.INPUT_ARGS] = args_info_list
216
- self.api_data_category = Const.KWARGS
217
- kwargs_info_list = self.analyze_element(module_input_output.kwargs)
218
- api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
219
-
220
- # check whether data_mode contains forward or output
221
- if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
222
- api_info_struct[name] = api_info_struct.get(name, {})
223
- self.api_data_category = Const.OUTPUT
224
- output_info_list = self.analyze_element(module_input_output.output_tuple)
225
- api_info_struct[name][Const.OUTPUT] = output_info_list
226
- return api_info_struct
227
-
228
- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
229
- api_info_struct = {}
230
- if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
231
- api_info_struct[name] = {}
232
- self.api_data_category = Const.INPUT
233
- args_info_list = self.analyze_element(module_input_output.args_tuple)
234
- api_info_struct[name][Const.INPUT_ARGS] = args_info_list
235
- self.api_data_category = Const.KWARGS
236
- kwargs_info_list = self.analyze_element(module_input_output.kwargs)
237
- api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
238
- return api_info_struct
239
-
240
- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
241
- concat_args = module_input_output.concat_args_and_kwargs()
242
- api_info_struct = {}
243
- if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
244
- api_info_struct[name] = {}
245
- self.api_data_category = Const.OUTPUT
246
- output_info_list = self.analyze_element(concat_args)
247
- api_info_struct[name][Const.OUTPUT] = output_info_list
248
- return api_info_struct
249
-
250
- def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
251
- api_info_struct = {}
252
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
253
- api_info_struct[name] = {}
254
- self.api_data_category = Const.INPUT
255
- input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
256
- api_info_struct[name][Const.INPUT] = input_info_list
257
-
258
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
259
- api_info_struct[name] = api_info_struct.get(name, {})
260
- self.api_data_category = Const.OUTPUT
261
- output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
262
- api_info_struct[name][Const.OUTPUT] = output_info_list
263
-
264
- return api_info_struct
265
-
266
- def analyze_backward_input(self, name, module,
267
- module_input_output: ModuleBackwardInputs):
268
- api_info_struct = {}
269
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
270
- api_info_struct[name] = {}
271
- self.api_data_category = Const.INPUT
272
-
273
- input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
274
- api_info_struct[name][Const.INPUT] = input_info_list
275
- return api_info_struct
276
-
277
- def analyze_backward_output(self, name, module,
278
- module_input_output: ModuleBackwardOutputs):
279
- api_info_struct = {}
280
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
281
- api_info_struct[name] = {}
282
- self.api_data_category = Const.OUTPUT
283
-
284
- output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
285
- api_info_struct[name][Const.OUTPUT] = output_info_list
286
- return api_info_struct
287
-
288
- def get_save_file_path(self, suffix):
289
- file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
290
- dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
291
- suffix + file_format)
292
- file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
293
- return dump_data_name, file_path
1
+ import os
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Tuple, Dict, Optional, Any
5
+ import numpy as np
6
+ from msprobe.core.common.log import logger
7
+ from msprobe.core.common.utils import convert_tuple
8
+ from msprobe.core.common.const import Const
9
+
10
+
11
+ @dataclass
12
+ class ModuleForwardInputsOutputs:
13
+ args: Optional[Tuple]
14
+ kwargs: Optional[Dict]
15
+ output: Any
16
+
17
+ @property
18
+ def args_tuple(self):
19
+ return convert_tuple(self.args)
20
+
21
+ @property
22
+ def output_tuple(self):
23
+ return convert_tuple(self.output)
24
+
25
+ def concat_args_and_kwargs(self):
26
+ args = self.args + tuple(self.kwargs.values())
27
+ return args
28
+
29
+
30
+ @dataclass
31
+ class ModuleBackwardInputsOutputs:
32
+ grad_output: Optional[Tuple]
33
+ grad_input: Optional[Tuple]
34
+
35
+ @property
36
+ def grad_input_tuple(self):
37
+ return convert_tuple(self.grad_input)
38
+
39
+ @property
40
+ def grad_output_tuple(self):
41
+ return convert_tuple(self.grad_output)
42
+
43
+
44
+ @dataclass
45
+ class ModuleBackwardInputs:
46
+ grad_input: Optional[Tuple]
47
+
48
+ @property
49
+ def grad_input_tuple(self):
50
+ return convert_tuple(self.grad_input)
51
+
52
+
53
+ @dataclass
54
+ class ModuleBackwardOutputs:
55
+ grad_output: Optional[Tuple]
56
+
57
+ @property
58
+ def grad_output_tuple(self):
59
+ return convert_tuple(self.grad_output)
60
+
61
+
62
+ class TensorStatInfo:
63
+ def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
64
+ self.max = max_val
65
+ self.min = min_val
66
+ self.mean = mean_val
67
+ self.norm = norm_val
68
+
69
+
70
+ class BaseDataProcessor:
71
+ _recursive_key_stack = []
72
+ special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
73
+ bool, int, float, str, slice, type(Ellipsis))
74
+
75
+ def __init__(self, config, data_writer):
76
+ self.data_writer = data_writer
77
+ self.config = config
78
+ self.api_info_struct = {}
79
+ self.stack_info_struct = {}
80
+ self.current_api_or_module_name = None
81
+ self.api_data_category = None
82
+ self.current_iter = 0
83
+ self._return_forward_new_output = False
84
+ self._forward_new_output = None
85
+
86
+ @property
87
+ def data_path(self):
88
+ return self.data_writer.dump_tensor_data_dir
89
+
90
+ @property
91
+ def is_terminated(self):
92
+ return False
93
+
94
+ @staticmethod
95
+ def analyze_api_call_stack(name):
96
+ stack_str = []
97
+ for (_, path, line, func, code, _) in inspect.stack()[5:]:
98
+ if not code:
99
+ continue
100
+ stack_line = " ".join([
101
+ "File", ", ".join([
102
+ path,
103
+ " ".join(["line", str(line)]),
104
+ " ".join(["in", func]),
105
+ " ".join(["\n", code[0].strip()])
106
+ ])
107
+ ])
108
+ stack_str.append(stack_line)
109
+ stack_info_struct = {name: stack_str}
110
+ return stack_info_struct
111
+
112
+ @staticmethod
113
+ def transfer_type(data):
114
+ dtype = str(type(data))
115
+ if 'int' in dtype:
116
+ return int(data)
117
+ elif 'float' in dtype:
118
+ return float(data)
119
+ else:
120
+ return data
121
+
122
+ @staticmethod
123
+ def _convert_numpy_to_builtin(arg):
124
+ type_mapping = {
125
+ np.integer: int,
126
+ np.floating: float,
127
+ np.bool_: bool,
128
+ np.complexfloating: complex,
129
+ np.str_: str,
130
+ np.byte: bytes,
131
+ np.unicode_: str
132
+ }
133
+ for numpy_type, builtin_type in type_mapping.items():
134
+ if isinstance(arg, numpy_type):
135
+ return builtin_type(arg), type(arg).__name__
136
+ return arg, ''
137
+
138
@staticmethod
def _analyze_builtin(arg):
    """Describe a builtin-typed argument as a JSON-friendly dict.

    Args:
        arg: A slice, Ellipsis, or another value handled as a builtin.

    Returns:
        dict: ``{"type": ..., "value": ...}``. For a slice the value is the list
        ``[start, stop, step]`` with every non-None bound converted to ``int``
        (or ``None`` when conversion fails). Ellipsis is rendered as ``"..."``.
    """
    single_arg = {}
    if isinstance(arg, slice):
        # The slice parameter may be of the tensor, numpy or other types.
        # It needs to be converted to the Python value type before JSON serialization
        single_arg.update({"type": "slice"})
        values = []
        for value in (arg.start, arg.stop, arg.step):
            if value is not None:
                try:
                    value = int(value)
                except (ValueError, TypeError):
                    # int() raises TypeError, not ValueError, for types it cannot
                    # convert at all; both are downgraded to a warning.
                    logger.warning(f"The data type {type(value)} cannot be converted to int type.")
                    value = None
            values.append(value)
        single_arg.update({"value": values})
    else:
        single_arg.update({"type": type(arg).__name__})
        # When arg is Ellipsis(...) type, it needs to be converted to str("...") type
        single_arg.update({"value": arg if arg is not Ellipsis else "..."})
    return single_arg
160
+
161
@staticmethod
def _analyze_numpy(value, numpy_type):
    """Pair a value with the name of its numpy type.

    Returns:
        dict: ``{"type": numpy_type, "value": value}``.
    """
    description = dict(type=numpy_type, value=value)
    return description
164
+
165
@classmethod
def get_special_types(cls):
    """Return the class attribute ``special_type``."""
    leaf_types = cls.special_type
    return leaf_types
168
+
169
@classmethod
def recursive_apply_transform(cls, args, transform):
    """Apply ``transform`` to every leaf of a nested list/tuple/dict structure.

    While descending, the list index or dict key of each level (as a string) is
    pushed onto the class-level ``_recursive_key_stack`` and popped on the way out.

    Args:
        args: A leaf value (instance of ``cls.get_special_types()``), a list,
            a tuple, a dict, or None.
        transform: Callable invoked as ``transform(leaf, key_stack)``.

    Returns:
        The transformed structure with the same container types, or None for
        None input and for unsupported types (a warning is logged for the latter).
    """
    if isinstance(args, cls.get_special_types()):
        return transform(args, cls._recursive_key_stack)
    if isinstance(args, (list, tuple)):
        result_list = []
        for i, arg in enumerate(args):
            cls._recursive_key_stack.append(str(i))
            try:
                result_list.append(cls.recursive_apply_transform(arg, transform))
            finally:
                # The stack is shared at class level: unwind it even when the
                # transform raises, so later calls do not see stale keys.
                cls._recursive_key_stack.pop()
        return type(args)(result_list)
    if isinstance(args, dict):
        result_dict = {}
        for k, arg in args.items():
            cls._recursive_key_stack.append(str(k))
            try:
                result_dict[k] = cls.recursive_apply_transform(arg, transform)
            finally:
                cls._recursive_key_stack.pop()
        return result_dict
    if args is not None:
        logger.warning(f"Data type {type(args)} is not supported.")
    return None
193
+
194
def if_return_forward_new_output(self):
    """Return the current value of the ``_return_forward_new_output`` flag."""
    flag = self._return_forward_new_output
    return flag
196
+
197
def get_forward_new_output(self):
    """Return the stored forward output and clear the ``_return_forward_new_output`` flag."""
    stored_output = self._forward_new_output
    self._return_forward_new_output = False
    return stored_output
200
+
201
def update_iter(self, current_iter):
    """Store ``current_iter`` on ``self.current_iter``."""
    setattr(self, "current_iter", current_iter)
203
+
204
def update_api_or_module_name(self, api_or_module_name):
    """Store ``api_or_module_name`` as the current name when it differs from the stored one."""
    name_changed = self.current_api_or_module_name != api_or_module_name
    if not name_changed:
        return
    self.current_api_or_module_name = api_or_module_name
207
+
208
def is_dump_for_data_mode(self, forward_backward, input_output):
    """
    Decide whether data should be dumped for the given direction and side.

    Args:
        forward_backward(str): The forward or backward mode to check.
        input_output(str): The input or output mode to check.

    Return:
        bool: True when ``Const.ALL``, ``forward_backward`` or ``input_output``
        is contained in ``self.config.data_mode``, False otherwise.
    """
    data_mode = self.config.data_mode
    candidates = (Const.ALL, forward_backward, input_output)
    return any(candidate in data_mode for candidate in candidates)
222
+
223
def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Pre-forward analysis entry point; this implementation does no work and returns None."""
    return None
225
+
226
def analyze_element(self, element):
    """Run ``self.analyze_single_element`` over every leaf of ``element``.

    Delegates the traversal to ``recursive_apply_transform`` and returns its result.
    """
    leaf_handler = self.analyze_single_element
    return self.recursive_apply_transform(element, leaf_handler)
228
+
229
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Analyze the inputs and outputs of a forward call.

    Args:
        name: Key under which the collected information is stored.
        module: Not used by this implementation.
        module_input_output: Provides ``args_tuple``, ``kwargs`` and ``output_tuple``.

    Returns:
        dict: ``{name: {...}}`` holding the analyzed args, kwargs and/or output,
        depending on the data mode; empty when neither side is selected.
    """
    collected = {}
    # Input side: positional arguments first, then keyword arguments.
    if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
        entry = collected[name] = {}
        self.api_data_category = Const.INPUT
        entry[Const.INPUT_ARGS] = self.analyze_element(module_input_output.args_tuple)
        self.api_data_category = Const.KWARGS
        entry[Const.INPUT_KWARGS] = self.analyze_element(module_input_output.kwargs)

    # Output side: reuse the entry created above when there is one.
    if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
        entry = collected.setdefault(name, {})
        self.api_data_category = Const.OUTPUT
        entry[Const.OUTPUT] = self.analyze_element(module_input_output.output_tuple)
    return collected
248
+
249
def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
    """Analyze the inputs of an in-place forward call before it runs.

    Args:
        name: Key under which the collected information is stored.
        module_input_output: Provides ``args_tuple`` and ``kwargs``.

    Returns:
        dict: ``{name: {...}}`` with the analyzed args and kwargs, or an empty
        dict when the data mode does not select forward inputs.
    """
    collected = {}
    if not self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
        return collected
    entry = collected[name] = {}
    self.api_data_category = Const.INPUT
    entry[Const.INPUT_ARGS] = self.analyze_element(module_input_output.args_tuple)
    self.api_data_category = Const.KWARGS
    entry[Const.INPUT_KWARGS] = self.analyze_element(module_input_output.kwargs)
    return collected
260
+
261
def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
    """Analyze an in-place forward call, recording its arguments as the output.

    Args:
        name: Key under which the collected information is stored.
        module_input_output: Provides ``concat_args_and_kwargs()``.

    Returns:
        dict: ``{name: {Const.OUTPUT: ...}}``, or an empty dict when the data
        mode does not select forward outputs.
    """
    # Evaluated up front, before the data-mode check.
    merged_args = module_input_output.concat_args_and_kwargs()
    collected = {}
    if not self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
        return collected
    entry = collected[name] = {}
    self.api_data_category = Const.OUTPUT
    entry[Const.OUTPUT] = self.analyze_element(merged_args)
    return collected
270
+
271
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
    """Analyze the gradient inputs and outputs of a backward call.

    Args:
        name: Key under which the collected information is stored.
        module: Not used by this implementation.
        module_input_output: Provides ``grad_input_tuple`` and ``grad_output_tuple``.

    Returns:
        dict: ``{name: {...}}`` with the analyzed input and/or output, depending
        on the data mode; empty when neither side is selected.
    """
    collected = {}
    if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
        entry = collected[name] = {}
        self.api_data_category = Const.INPUT
        entry[Const.INPUT] = self.analyze_element(module_input_output.grad_input_tuple)

    if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
        # Reuse the entry created for the input side when there is one.
        entry = collected.setdefault(name, {})
        self.api_data_category = Const.OUTPUT
        entry[Const.OUTPUT] = self.analyze_element(module_input_output.grad_output_tuple)

    return collected
286
+
287
def analyze_backward_input(self, name, module,
                           module_input_output: ModuleBackwardInputs):
    """Analyze only the gradient inputs of a backward call.

    Args:
        name: Key under which the collected information is stored.
        module: Not used by this implementation.
        module_input_output: Provides ``grad_input_tuple``.

    Returns:
        dict: ``{name: {Const.INPUT: ...}}``, or an empty dict when the data
        mode does not select backward inputs.
    """
    collected = {}
    if not self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
        return collected
    entry = collected[name] = {}
    self.api_data_category = Const.INPUT
    entry[Const.INPUT] = self.analyze_element(module_input_output.grad_input_tuple)
    return collected
297
+
298
def analyze_backward_output(self, name, module,
                            module_input_output: ModuleBackwardOutputs):
    """Analyze only the gradient outputs of a backward call.

    Args:
        name: Key under which the collected information is stored.
        module: Not used by this implementation.
        module_input_output: Provides ``grad_output_tuple``.

    Returns:
        dict: ``{name: {Const.OUTPUT: ...}}``, or an empty dict when the data
        mode does not select backward outputs.
    """
    collected = {}
    if not self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
        return collected
    entry = collected[name] = {}
    self.api_data_category = Const.OUTPUT
    entry[Const.OUTPUT] = self.analyze_element(module_input_output.grad_output_tuple)
    return collected
308
+
309
def get_save_file_path(self, suffix):
    """Build the file name and full path for a dumped data file.

    The name is the current API/module name, the current data category and
    ``suffix`` joined by ``Const.SEP``, followed by a framework-dependent
    file extension.

    Args:
        suffix (str): Last component of the file name, before the extension.

    Returns:
        tuple: ``(dump_data_name, file_path)`` where ``file_path`` lies inside
        the data writer's ``dump_tensor_data_dir``.
    """
    if self.config.framework == Const.PT_FRAMEWORK:
        file_format = Const.PT_SUFFIX
    else:
        file_format = Const.NUMPY_SUFFIX
    name_parts = (self.current_api_or_module_name, self.api_data_category, suffix)
    dump_data_name = Const.SEP.join(name_parts) + file_format
    file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
    return dump_data_name, file_path