mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from msprobe.core.data_dump.scope import build_scope, ListScope
|
|
5
|
+
from msprobe.core.data_dump.json_writer import DataWriter
|
|
6
|
+
from msprobe.core.common.log import logger
|
|
7
|
+
from msprobe.core.common.const import Const
|
|
8
|
+
from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_data_collector(config):
    """Factory helper: construct a DataCollector for *config* and hand it back."""
    collector = DataCollector(config)
    return collector
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DataCollector:
    """Dispatch API/module hook events to the data processor and the data writer.

    For every forward/backward call that passes ``check_scope_and_pid``, the
    configured data processor analyzes the call's inputs/outputs and the
    resulting info is handed to the DataWriter (directly or via ``handle_data``).
    """
    # NOTE(review): multi_output_apis and tasks_need_tensor_data are not read inside
    # this class; presumably consumed by callers elsewhere -- confirm before changing.
    multi_output_apis = ["_sort_", "npu_flash_attention"]
    tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
    # Dump levels for which update_construct records nothing.
    level_without_construct = ["L1", "L2"]

    def __init__(self, config):
        """Create the writer, data processor, optional module processor and scope from *config*."""
        self.config = config
        self.data_writer = DataWriter()
        self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
        # Only PyTorch gets a module processor; for any other framework this stays None.
        self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
            if self.config.framework == Const.PT_FRAMEWORK else None
        # module_name -> [latest forward index, stack of forward indices awaiting backward]
        self.module_count = {}
        if self.config.task == Const.FREE_BENCHMARK:
            self.scope = build_scope(ListScope, self.config.scope, self.config.list)
        else:
            self.scope = build_scope(None, self.config.scope, self.config.list)

    @property
    def dump_data_dir(self):
        """Directory the data writer uses for tensor data files."""
        return self.data_writer.dump_tensor_data_dir

    @property
    def dump_file_path(self):
        """Path of the data writer's dump file."""
        return self.data_writer.dump_file_path

    @staticmethod
    def check_scope_and_pid(scope, name, pid):
        """Return True if *name* passes *scope* (or no scope is set) and *pid* is this process."""
        return (not scope or scope.check(name)) and pid == os.getpid()

    @staticmethod
    def is_inplace(module):
        """Return the module's ``op_is_inplace`` attribute, or False when it is absent."""
        return getattr(module, "op_is_inplace", False)

    def if_return_forward_new_output(self):
        """Delegate to the data processor: whether a replacement forward output is pending."""
        return self.data_processor.if_return_forward_new_output()

    def get_forward_new_output(self):
        """Delegate to the data processor: fetch the replacement forward output."""
        return self.data_processor.get_forward_new_output()

    def visit_and_clear_overflow_status(self, api_or_module_name):
        """Delegate to the data processor's per-API/module overflow-status reset."""
        self.data_processor.visit_and_clear_overflow_status(api_or_module_name)

    def write_json(self):
        """Ask the data writer to write its JSON output."""
        self.data_writer.write_json()

    def update_data(self, data_info, msg=''):
        """Pass *data_info* to the writer and return *msg*, possibly with a status suffix.

        For the overflow-check task the data is written only when the processor
        reports an overflow, and *msg* gets a matching note either way. For every
        other task the data is always written and *msg* is returned unchanged.
        """
        if self.config.task == Const.OVERFLOW_CHECK:
            if self.data_processor.has_overflow:
                self.data_writer.update_data(data_info)
                msg += "Overflow detected."
            else:
                msg += "No Overflow, OK."
        else:
            self.data_writer.update_data(data_info)
        return msg

    def pre_forward_data_collect(self, name, module, pid, module_input_output):
        """Handle the pre-forward hook.

        Runs the processor's pre-forward analysis under the matching backward
        name. For in-place modules it also records the inputs before the call;
        forward_data_collect later records the same arguments as outputs.
        """
        backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
        if self.check_scope_and_pid(self.scope, backward_name, pid):
            self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
        if not self.is_inplace(module):
            return
        logger.info(f"API {name} is inplace.")
        if self.check_scope_and_pid(self.scope, name, pid):
            data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
            self.update_data(data_info)

    def forward_data_collect(self, name, module, pid, module_input_output):
        """Handle the forward hook: analyze the call and, except at level "L2", record it with its stack."""
        self.update_construct(name)
        if not self.check_scope_and_pid(self.scope, name, pid):
            return

        if not self.is_inplace(module):
            data_info = self.data_processor.analyze_forward(name, module, module_input_output)
        else:
            data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
        # At level "L2" the analysis still runs, but nothing is passed to the writer here.
        if self.config.level == "L2":
            return
        self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
        self.handle_data(name, data_info)

    def backward_data_collect(self, name, module, pid, module_input_output):
        """Handle the backward hook: analyze gradients and record them."""
        self.update_construct(name)
        if not self.check_scope_and_pid(self.scope, name, pid):
            return

        data_info = self.data_processor.analyze_backward(name, module, module_input_output)
        self.handle_data(name, data_info)

    def update_construct(self, name):
        """Record module-hierarchy ("construct") info for *name*.

        Nothing is recorded for levels in ``level_without_construct``, or when
        there is no module processor (``__init__`` only creates one for PyTorch).
        """
        if self.config.level in DataCollector.level_without_construct:
            return
        if self.module_processor is None:
            # Non-PyTorch frameworks have no module processor, so there is no
            # hierarchy to record; dereferencing None here would raise.
            return
        self.data_writer.update_construct({name: self.module_processor.api_parent_node})
        self.data_writer.update_construct(self.module_processor.module_node)

    def handle_data(self, name, data_info):
        """Write non-empty *data_info* (logging a status line), then let the writer flush if its buffer is full."""
        msg = f"msProbe is collecting data on {name}. "
        if data_info:
            msg = self.update_data(data_info, msg)
            logger.info(msg)
        self.data_writer.flush_data_when_buffer_is_full()

    def module_count_func(self, name, name_template):
        """Return the call index for the module named in *name*.

        The module key is the third-from-last ``Const.SEP``-separated field of
        *name*. On a forward call (``"forward"`` in *name_template*) the module's
        counter advances and the new index is pushed on its pending stack. On a
        backward call the top pending index is popped, or "abnormal" is returned
        when nothing is pending.
        """
        module_name = name.split(Const.SEP)[-3]
        if "forward" in name_template:
            if module_name not in self.module_count:
                # First forward call: index 0, already pushed on the pending stack.
                self.module_count[module_name] = [0, [0]]
            else:
                # Discard the top pending index if it is not the latest forward index.
                if self.module_count[module_name][-1] and \
                        self.module_count[module_name][0] != self.module_count[module_name][-1][-1]:
                    self.module_count[module_name][-1].pop()
                self.module_count[module_name][0] += 1
                self.module_count[module_name][-1].append(self.module_count[module_name][0])
            index = self.module_count[module_name][0]
        else:
            backward_stack = self.module_count[module_name][-1] if module_name in self.module_count else []
            if not backward_stack:
                index = "abnormal"
            else:
                index = backward_stack.pop()
        return index

    def update_dump_paths(self, *args):
        """Forward new dump paths to the writer and (re)initialize its JSON file for this task/level."""
        self.data_writer.update_dump_paths(*args)
        self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)

    def update_iter(self, current_iter):
        """Tell the data processor which iteration is current."""
        self.data_processor.update_iter(current_iter)
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import inspect
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Tuple, Dict, Optional, Any
|
|
5
|
+
import numpy as np
|
|
6
|
+
from msprobe.core.common.log import logger
|
|
7
|
+
from msprobe.core.common.utils import convert_tuple
|
|
8
|
+
from msprobe.core.common.const import Const
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class ModuleForwardInputsOutputs:
    """Positional args, keyword args and output captured for one forward call."""
    args: Optional[Tuple]
    kwargs: Optional[Dict]
    output: Any

    @property
    def args_tuple(self):
        """``args`` normalized through ``convert_tuple``."""
        return convert_tuple(self.args)

    @property
    def output_tuple(self):
        """``output`` normalized through ``convert_tuple``."""
        return convert_tuple(self.output)

    def concat_args_and_kwargs(self):
        """Return the positional args followed by the keyword-argument values, as one tuple.

        Both fields are declared Optional, so ``None`` is treated as empty; a
        non-tuple sequence of positional args (e.g. a list) is converted to a tuple.
        """
        args = () if self.args is None else tuple(self.args)
        kwargs = {} if self.kwargs is None else self.kwargs
        return args + tuple(kwargs.values())
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ModuleBackwardInputsOutputs:
    """Gradients seen by a backward hook.

    Field order (grad_output, then grad_input) is the positional constructor order.
    """
    grad_output: Optional[Tuple]
    grad_input: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` passed through ``convert_tuple``."""
        value = self.grad_output
        return convert_tuple(value)

    @property
    def grad_input_tuple(self):
        """``grad_input`` passed through ``convert_tuple``."""
        value = self.grad_input
        return convert_tuple(value)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TensorStatInfo:
    """Plain holder for a tensor's summary statistics.

    Any statistic not supplied to the constructor is stored as None.
    """

    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
        # Stored under the short attribute names max / min / mean / norm.
        self.max, self.min, self.mean, self.norm = max_val, min_val, mean_val, norm_val
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class BaseDataProcessor:
    """Framework-independent base class for dump data processors.

    Subclasses are expected to provide ``analyze_element`` (called by the
    ``analyze_*`` methods below but not defined here) and may widen the set of
    leaf types by overriding ``get_special_types``.
    """
    # Keys/indices leading from the top-level argument to the element currently
    # being transformed. Shared at class level; recursive_apply_transform pops
    # every key it pushes, even when the transform raises.
    _recursive_key_stack = []
    # Leaf types handed straight to the transform by recursive_apply_transform.
    # np.unicode_ is intentionally absent: it was only an alias of np.str_ and was
    # removed in NumPy 2.0, where referencing it breaks this class at import time.
    special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte,
                    bool, int, float, str, slice)

    def __init__(self, config, data_writer):
        self.data_writer = data_writer
        self.config = config
        self.api_info_struct = {}
        self.stack_info_struct = {}
        # Name of the API/module being processed; used by get_save_file_path.
        self.current_api_or_module_name = None
        # Const.INPUT / Const.KWARGS / Const.OUTPUT, set by the analyze_* methods.
        self.api_data_category = None
        self.has_overflow = False
        self.current_iter = 0
        self._return_forward_new_output = False
        self._forward_new_output = None

    @property
    def data_path(self):
        """Directory the data writer uses for tensor data files."""
        return self.data_writer.dump_tensor_data_dir

    @staticmethod
    def analyze_api_call_stack(name):
        """Return ``{name: [formatted frame, ...]}`` describing the current call stack.

        The innermost 5 frames are skipped (presumably the hook/dump machinery;
        the count depends on call depth, so do not wrap this call). Frames
        without source context are left out.
        """
        stack_str = []
        for (_, path, line, func, code, _) in inspect.stack()[5:]:
            if not code:
                continue
            stack_line = " ".join([
                "File", ", ".join([
                    path,
                    " ".join(["line", str(line)]),
                    " ".join(["in", func]),
                    " ".join(["\n", code[0].strip()])
                ])
            ])
            stack_str.append(stack_line)
        stack_info_struct = {name: stack_str}
        return stack_info_struct

    @staticmethod
    def _convert_numpy_to_builtin(arg):
        """Convert a NumPy scalar to the matching Python builtin.

        Returns ``(converted_value, numpy_type_name)`` for a NumPy scalar and
        ``(arg, '')`` unchanged for anything else.
        """
        # Checked in insertion order; the first matching NumPy base class wins.
        # np.byte is np.int8 and is therefore already covered by np.integer.
        type_mapping = {
            np.integer: int,
            np.floating: float,
            np.bool_: bool,
            np.complexfloating: complex,
            np.str_: str,
        }
        for numpy_type, builtin_type in type_mapping.items():
            if isinstance(arg, numpy_type):
                return builtin_type(arg), type(arg).__name__
        return arg, ''

    @staticmethod
    def _analyze_numpy(value, numpy_type):
        """Describe a converted NumPy scalar as ``{"type": ..., "value": ...}``."""
        return {"type": numpy_type, "value": value}

    @staticmethod
    def _analyze_builtin(arg):
        """Describe a builtin value; a slice is stored as ``[start, stop, step]``."""
        single_arg = {}
        if isinstance(arg, slice):
            single_arg.update({"type": "slice"})
            single_arg.update({"value": [arg.start, arg.stop, arg.step]})
        else:
            single_arg.update({"type": type(arg).__name__})
            single_arg.update({"value": arg})
        return single_arg

    @classmethod
    def get_special_types(cls):
        """Leaf types passed directly to the transform in recursive_apply_transform."""
        return cls.special_type

    @classmethod
    def recursive_apply_transform(cls, args, transform):
        """Apply ``transform(leaf, key_stack)`` to every leaf of a nested structure.

        Lists, tuples (including namedtuples) and dicts are walked recursively and
        rebuilt with the transformed values; leaves are instances of
        ``get_special_types()``. ``None`` maps to ``None``; any other type is logged
        as unsupported and mapped to ``None``.
        """
        if isinstance(args, cls.get_special_types()):
            return transform(args, cls._recursive_key_stack)
        if isinstance(args, (list, tuple)):
            result_list = []
            for i, arg in enumerate(args):
                cls._recursive_key_stack.append(str(i))
                try:
                    result_list.append(cls.recursive_apply_transform(arg, transform))
                finally:
                    cls._recursive_key_stack.pop()
            if hasattr(args, "_make"):
                # namedtuple constructors take fields positionally, not one iterable.
                return type(args)._make(result_list)
            return type(args)(result_list)
        if isinstance(args, dict):
            result_dict = {}
            for k, arg in args.items():
                cls._recursive_key_stack.append(str(k))
                try:
                    result_dict[k] = cls.recursive_apply_transform(arg, transform)
                finally:
                    cls._recursive_key_stack.pop()
            return result_dict
        if args is not None:
            logger.warning(f"Data type {type(args)} is not supported.")
        return None

    def if_return_forward_new_output(self):
        """Whether a replacement forward output is pending."""
        return self._return_forward_new_output

    def get_forward_new_output(self):
        """Return the pending replacement forward output and clear the pending flag."""
        self._return_forward_new_output = False
        return self._forward_new_output

    def update_iter(self, current_iter):
        """Record the current iteration number."""
        self.current_iter = current_iter

    def visit_and_clear_overflow_status(self, api_or_module_name):
        """Switch to *api_or_module_name*, resetting has_overflow if the name changed."""
        if self.current_api_or_module_name != api_or_module_name:
            self.current_api_or_module_name = api_or_module_name
            self.has_overflow = False

    def is_dump_for_data_mode(self, forward_backward, input_output):
        """
        Compare the parameters with data_mode to determine whether to dump.

        Args:
            forward_backward(str): The forward or backward mode to check.
            input_output(str): The input or output mode to check.

        Return:
            bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
        """
        return (Const.ALL in self.config.data_mode or
                forward_backward in self.config.data_mode or
                input_output in self.config.data_mode)

    def analyze_pre_forward(self, name, module, module_input_output: "ModuleForwardInputsOutputs"):
        """Pre-forward hook point; does nothing in the base class."""
        pass

    def analyze_forward(self, name, module, module_input_output: "ModuleForwardInputsOutputs"):
        """Analyze a forward call's args/kwargs and output as allowed by data_mode.

        Returns ``{name: {...}}`` with any of Const.INPUT_ARGS, Const.INPUT_KWARGS
        and Const.OUTPUT, or an empty dict when data_mode selects none of them.
        """
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):  # check whether data_mode contains forward or input
            api_info_struct[name] = {}
            self.api_data_category = Const.INPUT
            args_info_list = self.analyze_element(module_input_output.args_tuple)
            api_info_struct[name][Const.INPUT_ARGS] = args_info_list
            self.api_data_category = Const.KWARGS
            kwargs_info_list = self.analyze_element(module_input_output.kwargs)
            api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list

        if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):  # check whether data_mode contains forward or output
            api_info_struct[name] = api_info_struct.get(name, {})
            self.api_data_category = Const.OUTPUT
            output_info_list = self.analyze_element(module_input_output.output_tuple)
            api_info_struct[name][Const.OUTPUT] = output_info_list
        return api_info_struct

    def analyze_pre_forward_inplace(self, name, module_input_output: "ModuleForwardInputsOutputs"):
        """Analyze an in-place call's args/kwargs before it runs (input part of analyze_forward)."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
            api_info_struct[name] = {}
            self.api_data_category = Const.INPUT
            args_info_list = self.analyze_element(module_input_output.args_tuple)
            api_info_struct[name][Const.INPUT_ARGS] = args_info_list
            self.api_data_category = Const.KWARGS
            kwargs_info_list = self.analyze_element(module_input_output.kwargs)
            api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
        return api_info_struct

    def analyze_forward_inplace(self, name, module_input_output: "ModuleForwardInputsOutputs"):
        """Analyze an in-place call after it ran, treating its args + kwargs values as the output."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
            # Build the combined tuple only when it is actually going to be dumped.
            concat_args = module_input_output.concat_args_and_kwargs()
            api_info_struct[name] = {}
            self.api_data_category = Const.OUTPUT
            output_info_list = self.analyze_element(concat_args)
            api_info_struct[name][Const.OUTPUT] = output_info_list
        return api_info_struct

    def analyze_backward(self, name, module, module_input_output: "ModuleBackwardInputsOutputs"):
        """Analyze backward gradients as allowed by data_mode.

        grad_input is filed under Const.GRAD_INPUT with category Const.OUTPUT, and
        grad_output under Const.GRAD_OUTPUT with category Const.INPUT.
        """
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
            api_info_struct[name] = {}
            self.api_data_category = Const.OUTPUT
            input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
            api_info_struct[name][Const.GRAD_INPUT] = input_info_list

        if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
            api_info_struct[name] = api_info_struct.get(name, {})
            self.api_data_category = Const.INPUT
            output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
            api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list

        return api_info_struct

    def get_save_file_path(self, suffix):
        """Return ``(file_name, full_path)`` for the current API/module, category and *suffix*.

        The extension is "pt" for the PyTorch framework and "npy" otherwise.
        """
        file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy"
        dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
                          suffix + Const.SEP + file_format)
        file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
        return dump_data_name, file_path
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from msprobe.core.common.const import Const
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DataProcessorFactory:
    """Registry of data-processor classes per (framework, task) and module-processor classes per framework.

    Registrations are filled in by ``register_processors``, which
    ``create_processor`` calls every time, so framework-specific modules are
    imported only for the framework actually requested.
    """
    # (framework, task) -> data processor class
    _data_processor = {}
    # framework -> module processor class
    _module_processor = {}

    @classmethod
    def register_processor(cls, framework, task, processor_class):
        """Register *processor_class* for the (framework, task) pair, replacing any previous entry."""
        key = (framework, task)
        cls._data_processor[key] = processor_class

    @classmethod
    def register_module_processor(cls, framework, processor_class):
        """Register the module processor class for *framework*."""
        cls._module_processor[framework] = processor_class

    @classmethod
    def get_module_processor(cls, framework):
        """Return the registered module processor class for *framework*.

        Raises:
            ValueError: if none is registered.
        """
        processor_class = cls._module_processor.get(framework)
        if not processor_class:
            raise ValueError(f"ModuleProcesser not found for framework: {framework}")
        return processor_class

    @classmethod
    def create_processor(cls, config, data_writer):
        """Instantiate the data processor for ``config.framework`` and the effective task.

        Level "L2" always selects the kernel-dump processor regardless of
        ``config.task``.

        Raises:
            ValueError: if no processor is registered for that combination. The
                message reports the task actually looked up.
        """
        cls.register_processors(config.framework)
        task = Const.KERNEL_DUMP if config.level == "L2" else config.task
        key = (config.framework, task)
        processor_class = cls._data_processor.get(key)
        if not processor_class:
            raise ValueError(f"Processor not found for framework: {config.framework}, task: {task}")
        return processor_class(config, data_writer)

    @classmethod
    def register_processors(cls, framework):
        """Import and register the processor classes for *framework*; unknown frameworks register nothing."""
        if framework == Const.PT_FRAMEWORK:
            from .pytorch_processor import (
                StatisticsDataProcessor as PytorchStatisticsDataProcessor,
                TensorDataProcessor as PytorchTensorDataProcessor,
                OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
                FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
                KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
            )
            from ....pytorch.module_processer import ModuleProcesser
            cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
            cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
            cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
            cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
            cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
            cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
        elif framework == Const.MS_FRAMEWORK:
            # NOTE(review): mindspore_processor does not appear to ship next to
            # pytorch_processor in the 1.0.1 wheel; this import may raise
            # ModuleNotFoundError -- confirm. No KERNEL_DUMP processor is registered
            # for MindSpore, so level "L2" makes create_processor raise ValueError.
            from .mindspore_processor import (
                StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
                TensorDataProcessor as MindsporeTensorDataProcessor,
                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
                FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor
            )
            cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
            cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
            cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
            cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor)
|