mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,140 @@
1
+
2
+ import os
3
+
4
+ from msprobe.core.data_dump.scope import build_scope, ListScope
5
+ from msprobe.core.data_dump.json_writer import DataWriter
6
+ from msprobe.core.common.log import logger
7
+ from msprobe.core.common.const import Const
8
+ from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
9
+
10
+
11
def build_data_collector(config):
    """Create the data collector for a dump run.

    Args:
        config: Configuration object, handed unchanged to ``DataCollector``.

    Returns:
        DataCollector: A new collector built from ``config``.
    """
    collector = DataCollector(config)
    return collector
13
+
14
+
15
class DataCollector:
    """Glue between the hooks, the data processor and the data writer.

    The collector decides (by scope and process id) whether a call should be
    recorded, asks the data processor to analyze it and hands the result to
    the data writer.
    """

    # The three lists below are not referenced inside this class.
    multi_output_apis = ["_sort_", "npu_flash_attention"]
    tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
    # Levels for which update_construct() records nothing.
    level_without_construct = ["L1", "L2"]

    def __init__(self, config):
        """Build the writer, the processor and the scope filter from ``config``.

        Args:
            config: Object exposing ``framework``, ``task``, ``level``,
                ``scope`` and ``list`` attributes.
        """
        self.config = config
        self.data_writer = DataWriter()
        self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
        # A module processor is only looked up for the PyTorch framework.
        if self.config.framework == Const.PT_FRAMEWORK:
            self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
        else:
            self.module_processor = None
        # module name -> [latest forward index, list of indices awaiting backward]
        self.module_count = {}
        scope_class = ListScope if self.config.task == Const.FREE_BENCHMARK else None
        self.scope = build_scope(scope_class, self.config.scope, self.config.list)

    @property
    def dump_data_dir(self):
        """Tensor dump directory as reported by the data writer."""
        return self.data_writer.dump_tensor_data_dir

    @property
    def dump_file_path(self):
        """Dump file path as reported by the data writer."""
        return self.data_writer.dump_file_path

    @staticmethod
    def check_scope_and_pid(scope, name, pid):
        """Tell whether ``name`` passes the scope filter in the current process.

        A falsy ``scope`` accepts every name. ``pid`` must equal the id of the
        running process.
        """
        in_scope = not scope or scope.check(name)
        return in_scope and pid == os.getpid()

    @staticmethod
    def is_inplace(module):
        """Return the module's ``op_is_inplace`` attribute, ``False`` if absent."""
        return getattr(module, "op_is_inplace", False)

    def if_return_forward_new_output(self):
        """Delegate to the data processor."""
        return self.data_processor.if_return_forward_new_output()

    def get_forward_new_output(self):
        """Delegate to the data processor."""
        return self.data_processor.get_forward_new_output()

    def visit_and_clear_overflow_status(self, api_or_module_name):
        """Delegate to the data processor."""
        self.data_processor.visit_and_clear_overflow_status(api_or_module_name)

    def write_json(self):
        """Ask the data writer to write its JSON output."""
        self.data_writer.write_json()

    def update_data(self, data_info, msg=''):
        """Forward ``data_info`` to the writer and return the extended message.

        For the overflow-check task the data is only forwarded when the
        processor reports an overflow, and a verdict is appended to ``msg``.
        For every other task the data is always forwarded and ``msg`` is
        returned unchanged.
        """
        if self.config.task != Const.OVERFLOW_CHECK:
            self.data_writer.update_data(data_info)
            return msg
        if not self.data_processor.has_overflow:
            return msg + "No Overflow, OK."
        self.data_writer.update_data(data_info)
        return msg + "Overflow detected."

    def pre_forward_data_collect(self, name, module, pid, module_input_output):
        """Run the pre-forward analysis and, for in-place APIs, record the inputs."""
        backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
        if self.check_scope_and_pid(self.scope, backward_name, pid):
            self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
        if self.is_inplace(module):
            logger.info(f"API {name} is inplace.")
            if self.check_scope_and_pid(self.scope, name, pid):
                inplace_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
                self.update_data(inplace_info)

    def forward_data_collect(self, name, module, pid, module_input_output):
        """Record the forward data of ``name`` when it is in scope."""
        self.update_construct(name)
        if not self.check_scope_and_pid(self.scope, name, pid):
            return

        processor = self.data_processor
        if self.is_inplace(module):
            data_info = processor.analyze_forward_inplace(name, module_input_output)
        else:
            data_info = processor.analyze_forward(name, module, module_input_output)
        # At level "L2" the analysis above still runs, but nothing is recorded here.
        if self.config.level == "L2":
            return
        call_stack = self.data_processor.analyze_api_call_stack(name)
        self.data_writer.update_stack(call_stack)
        self.handle_data(name, data_info)

    def backward_data_collect(self, name, module, pid, module_input_output):
        """Record the backward data of ``name`` when it is in scope."""
        self.update_construct(name)
        if self.check_scope_and_pid(self.scope, name, pid):
            data_info = self.data_processor.analyze_backward(name, module, module_input_output)
            self.handle_data(name, data_info)

    def update_construct(self, name):
        """Record the construct information unless the level opts out of it."""
        if self.config.level in DataCollector.level_without_construct:
            return
        self.data_writer.update_construct({name: self.module_processor.api_parent_node})
        self.data_writer.update_construct(self.module_processor.module_node)

    def handle_data(self, name, data_info):
        """Store non-empty ``data_info``, log the outcome, then let the writer flush."""
        msg = f"msProbe is collecting data on {name}. "
        if data_info:
            logger.info(self.update_data(data_info, msg))
        self.data_writer.flush_data_when_buffer_is_full()

    def module_count_func(self, name, name_template):
        """Return the call index for a module's forward or backward pass.

        The module name is the third-from-last ``Const.SEP``-separated field
        of ``name``. A forward call (``"forward"`` in ``name_template``)
        returns the module's latest forward index; any other call pops the
        most recent pending index, or returns ``"abnormal"`` when none is
        pending.
        """
        module_name = name.split(Const.SEP)[-3]
        if "forward" not in name_template:
            pending = self.module_count[module_name][-1] if module_name in self.module_count else []
            return pending.pop() if pending else "abnormal"

        if module_name not in self.module_count:
            self.module_count[module_name] = [0, [0]]
        else:
            record = self.module_count[module_name]
            pending = record[-1]
            # Drop the top pending index when it differs from the latest forward index.
            if pending and record[0] != pending[-1]:
                pending.pop()
            record[0] += 1
            pending.append(record[0])
        return self.module_count[module_name][0]

    def update_dump_paths(self, *args):
        """Pass the dump paths to the writer and initialise its JSON file."""
        self.data_writer.update_dump_paths(*args)
        self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)

    def update_iter(self, current_iter):
        """Delegate to the data processor."""
        self.data_processor.update_iter(current_iter)
@@ -0,0 +1,245 @@
1
+ import os
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Tuple, Dict, Optional, Any
5
+ import numpy as np
6
+ from msprobe.core.common.log import logger
7
+ from msprobe.core.common.utils import convert_tuple
8
+ from msprobe.core.common.const import Const
9
+
10
+
11
@dataclass
class ModuleForwardInputsOutputs:
    """Positional arguments, keyword arguments and output of one forward call."""
    args: Optional[Tuple]
    kwargs: Optional[Dict]
    output: Any

    @property
    def args_tuple(self):
        """``args`` passed through ``convert_tuple``."""
        return convert_tuple(self.args)

    @property
    def output_tuple(self):
        """``output`` passed through ``convert_tuple``."""
        return convert_tuple(self.output)

    def concat_args_and_kwargs(self):
        """Return the positional arguments followed by the keyword-argument values.

        Both fields are declared optional, so ``None`` is treated as "no
        arguments" instead of raising ``TypeError``.

        Returns:
            tuple: ``args`` followed by the values of ``kwargs`` in dict order.
        """
        positional = () if self.args is None else tuple(self.args)
        keyword_values = () if self.kwargs is None else tuple(self.kwargs.values())
        return positional + keyword_values
28
+
29
+
30
@dataclass
class ModuleBackwardInputsOutputs:
    """Gradients captured for one backward call."""
    grad_output: Optional[Tuple]
    grad_input: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` passed through ``convert_tuple``."""
        gradients = self.grad_output
        return convert_tuple(gradients)

    @property
    def grad_input_tuple(self):
        """``grad_input`` passed through ``convert_tuple``."""
        gradients = self.grad_input
        return convert_tuple(gradients)
42
+
43
+
44
class TensorStatInfo:
    """Plain holder for the max, min, mean and norm statistics of a tensor."""

    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
        """Store the four statistics; each one defaults to ``None``."""
        self.max, self.min = max_val, min_val
        self.mean, self.norm = mean_val, norm_val
50
+
51
+
52
+ class BaseDataProcessor:
53
+ _recursive_key_stack = []
54
+ special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
55
+ bool, int, float, str, slice)
56
+
57
+ def __init__(self, config, data_writer):
58
+ self.data_writer = data_writer
59
+ self.config = config
60
+ self.api_info_struct = {}
61
+ self.stack_info_struct = {}
62
+ self.current_api_or_module_name = None
63
+ self.api_data_category = None
64
+ self.has_overflow = False
65
+ self.current_iter = 0
66
+ self._return_forward_new_output = False
67
+ self._forward_new_output = None
68
+
69
+ @property
70
+ def data_path(self):
71
+ return self.data_writer.dump_tensor_data_dir
72
+
73
+ @staticmethod
74
+ def analyze_api_call_stack(name):
75
+ stack_str = []
76
+ for (_, path, line, func, code, _) in inspect.stack()[5:]:
77
+ if not code:
78
+ continue
79
+ stack_line = " ".join([
80
+ "File", ", ".join([
81
+ path,
82
+ " ".join(["line", str(line)]),
83
+ " ".join(["in", func]),
84
+ " ".join(["\n", code[0].strip()])
85
+ ])
86
+ ])
87
+ stack_str.append(stack_line)
88
+ stack_info_struct = {name: stack_str}
89
+ return stack_info_struct
90
+
91
+ @staticmethod
92
+ def _convert_numpy_to_builtin(arg):
93
+ type_mapping = {
94
+ np.integer: int,
95
+ np.floating: float,
96
+ np.bool_: bool,
97
+ np.complexfloating: complex,
98
+ np.str_: str,
99
+ np.byte: bytes,
100
+ np.unicode_: str
101
+ }
102
+ for numpy_type, builtin_type in type_mapping.items():
103
+ if isinstance(arg, numpy_type):
104
+ return builtin_type(arg), type(arg).__name__
105
+ return arg, ''
106
+
107
+ @staticmethod
108
+ def _analyze_numpy(value, numpy_type):
109
+ return {"type": numpy_type, "value": value}
110
+
111
+ @staticmethod
112
+ def _analyze_builtin(arg):
113
+ single_arg = {}
114
+ if isinstance(arg, slice):
115
+ single_arg.update({"type": "slice"})
116
+ single_arg.update({"value": [arg.start, arg.stop, arg.step]})
117
+ else:
118
+ single_arg.update({"type": type(arg).__name__})
119
+ single_arg.update({"value": arg})
120
+ return single_arg
121
+
122
+ @classmethod
123
+ def get_special_types(cls):
124
+ return cls.special_type
125
+
126
+ @classmethod
127
+ def recursive_apply_transform(cls, args, transform):
128
+ if isinstance(args, cls.get_special_types()):
129
+ arg_transform = transform(args, cls._recursive_key_stack)
130
+ return arg_transform
131
+ elif isinstance(args, (list, tuple)):
132
+ result_list = []
133
+ for i, arg in enumerate(args):
134
+ cls._recursive_key_stack.append(str(i))
135
+ result_list.append(cls.recursive_apply_transform(arg, transform))
136
+ cls._recursive_key_stack.pop()
137
+ return type(args)(result_list)
138
+ elif isinstance(args, dict):
139
+ resutl_dict = {}
140
+ for k, arg in args.items():
141
+ cls._recursive_key_stack.append(str(k))
142
+ resutl_dict[k] = cls.recursive_apply_transform(arg, transform)
143
+ cls._recursive_key_stack.pop()
144
+ return resutl_dict
145
+ elif args is not None:
146
+ logger.warning(f"Data type {type(args)} is not supported.")
147
+ return None
148
+ else:
149
+ return None
150
+
151
+ def if_return_forward_new_output(self):
152
+ return self._return_forward_new_output
153
+
154
+ def get_forward_new_output(self):
155
+ self._return_forward_new_output = False
156
+ return self._forward_new_output
157
+
158
+ def update_iter(self, current_iter):
159
+ self.current_iter = current_iter
160
+
161
+ def visit_and_clear_overflow_status(self, api_or_module_name):
162
+ if self.current_api_or_module_name != api_or_module_name:
163
+ self.current_api_or_module_name = api_or_module_name
164
+ self.has_overflow = False
165
+
166
+ def is_dump_for_data_mode(self, forward_backward, input_output):
167
+ """
168
+ Compare the parameters with data_mode to determine whether to dump.
169
+
170
+ Args:
171
+ forward_backward(str): The forward or backward mode to check.
172
+ input_output(str): The input or output mode to check.
173
+
174
+ Return:
175
+ bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
176
+ """
177
+ return (Const.ALL in self.config.data_mode or
178
+ forward_backward in self.config.data_mode or
179
+ input_output in self.config.data_mode)
180
+
181
+ def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs):
182
+ pass
183
+
184
+ def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
185
+ api_info_struct = {}
186
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input
187
+ api_info_struct[name] = {}
188
+ self.api_data_category = Const.INPUT
189
+ args_info_list = self.analyze_element(module_input_output.args_tuple)
190
+ api_info_struct[name][Const.INPUT_ARGS] = args_info_list
191
+ self.api_data_category = Const.KWARGS
192
+ kwargs_info_list = self.analyze_element(module_input_output.kwargs)
193
+ api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
194
+
195
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output
196
+ api_info_struct[name] = api_info_struct.get(name, {})
197
+ self.api_data_category = Const.OUTPUT
198
+ output_info_list = self.analyze_element(module_input_output.output_tuple)
199
+ api_info_struct[name][Const.OUTPUT] = output_info_list
200
+ return api_info_struct
201
+
202
+ def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
203
+ api_info_struct = {}
204
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
205
+ api_info_struct[name] = {}
206
+ self.api_data_category = Const.INPUT
207
+ args_info_list = self.analyze_element(module_input_output.args_tuple)
208
+ api_info_struct[name][Const.INPUT_ARGS] = args_info_list
209
+ self.api_data_category = Const.KWARGS
210
+ kwargs_info_list = self.analyze_element(module_input_output.kwargs)
211
+ api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
212
+ return api_info_struct
213
+
214
+ def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
215
+ concat_args = module_input_output.concat_args_and_kwargs()
216
+ api_info_struct = {}
217
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
218
+ api_info_struct[name] = {}
219
+ self.api_data_category = Const.OUTPUT
220
+ output_info_list = self.analyze_element(concat_args)
221
+ api_info_struct[name][Const.OUTPUT] = output_info_list
222
+ return api_info_struct
223
+
224
+ def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
225
+ api_info_struct = {}
226
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
227
+ api_info_struct[name] = {}
228
+ self.api_data_category = Const.OUTPUT
229
+ input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
230
+ api_info_struct[name][Const.GRAD_INPUT] = input_info_list
231
+
232
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
233
+ api_info_struct[name] = api_info_struct.get(name, {})
234
+ self.api_data_category = Const.INPUT
235
+ output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
236
+ api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
237
+
238
+ return api_info_struct
239
+
240
+ def get_save_file_path(self, suffix):
241
+ file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy"
242
+ dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
243
+ suffix + Const.SEP + file_format)
244
+ file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
245
+ return dump_data_name, file_path
@@ -0,0 +1,61 @@
1
+ from msprobe.core.common.const import Const
2
+
3
+
4
class DataProcessorFactory:
    """Registry of data-processor classes keyed by framework and task."""

    # (framework, task) -> data processor class
    _data_processor = {}
    # framework -> module processor class
    _module_processor = {}

    @classmethod
    def register_processor(cls, framework, task, processor_class):
        """Register ``processor_class`` for the ``(framework, task)`` pair."""
        cls._data_processor[(framework, task)] = processor_class

    @classmethod
    def register_module_processor(cls, framework, processor_class):
        """Register the module processor class of ``framework``."""
        cls._module_processor[framework] = processor_class

    @classmethod
    def get_module_processor(cls, framework):
        """Return the module processor class registered for ``framework``.

        Raises:
            ValueError: If nothing is registered for ``framework``.
        """
        registered = cls._module_processor.get(framework)
        if registered:
            return registered
        raise ValueError(f"ModuleProcesser not found for framework: {framework}")

    @classmethod
    def create_processor(cls, config, data_writer):
        """Instantiate the data processor matching ``config``.

        The processors of ``config.framework`` are registered first. Level
        ``"L2"`` always selects the kernel-dump processor; otherwise the
        lookup uses ``config.task``.

        Raises:
            ValueError: If no processor is registered for the resolved key.
        """
        cls.register_processors(config.framework)
        if config.level == "L2":
            task = Const.KERNEL_DUMP
        else:
            task = config.task
        registered = cls._data_processor.get((config.framework, task))
        if registered:
            return registered(config, data_writer)
        raise ValueError(f"Processor not found for framework: {config.framework}, task: {config.task}")

    @classmethod
    def register_processors(cls, framework):
        """Import and register every processor of ``framework``.

        The imports are done inside the method, so a framework's processor
        module is only loaded when that framework is requested. Frameworks
        other than PyTorch and MindSpore register nothing.
        """
        if framework == Const.PT_FRAMEWORK:
            from .pytorch_processor import (
                StatisticsDataProcessor as PytorchStatisticsDataProcessor,
                TensorDataProcessor as PytorchTensorDataProcessor,
                OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
                FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
                KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
            )
            from ....pytorch.module_processer import ModuleProcesser
            pytorch_processors = (
                (Const.STATISTICS, PytorchStatisticsDataProcessor),
                (Const.TENSOR, PytorchTensorDataProcessor),
                (Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor),
                (Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor),
                (Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor),
            )
            for task, processor_class in pytorch_processors:
                cls.register_processor(Const.PT_FRAMEWORK, task, processor_class)
            cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
        elif framework == Const.MS_FRAMEWORK:
            from .mindspore_processor import (
                StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
                TensorDataProcessor as MindsporeTensorDataProcessor,
                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
                FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor
            )
            mindspore_processors = (
                (Const.STATISTICS, MindsporeStatisticsDataProcessor),
                (Const.TENSOR, MindsporeTensorDataProcessor),
                (Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor),
                (Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor),
            )
            for task, processor_class in mindspore_processors:
                cls.register_processor(Const.MS_FRAMEWORK, task, processor_class)