mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -20,89 +20,108 @@ from collections import namedtuple
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
import openpyxl
|
|
22
22
|
from openpyxl.styles import PatternFill
|
|
23
|
+
from tqdm import tqdm
|
|
23
24
|
from msprobe.core.common.utils import get_header_index
|
|
24
25
|
from msprobe.core.common.file_utils import save_workbook
|
|
25
26
|
from msprobe.core.common.log import logger
|
|
26
|
-
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
27
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst, Const
|
|
28
|
+
from msprobe.core.common.utils import safe_get_value
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
class HighlightCheck(abc.ABC):
|
|
30
32
|
@abc.abstractmethod
|
|
31
|
-
def apply(self, info, color_columns,
|
|
33
|
+
def apply(self, info, color_columns, dump_mode):
|
|
32
34
|
raise NotImplementedError
|
|
33
35
|
|
|
34
36
|
|
|
37
|
+
def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
38
|
+
for i, (existing_num, existing_err_msg) in enumerate(color_list):
|
|
39
|
+
if num == existing_num:
|
|
40
|
+
color_list[i][1].append(highlight_err_msg)
|
|
41
|
+
return
|
|
42
|
+
color_list.append((num, [highlight_err_msg]))
|
|
43
|
+
|
|
44
|
+
|
|
35
45
|
class CheckOrderMagnitude(HighlightCheck):
|
|
36
46
|
"""检查Max diff的数量级差异"""
|
|
37
|
-
def apply(self, info, color_columns,
|
|
47
|
+
def apply(self, info, color_columns, dump_mode):
|
|
38
48
|
api_in, api_out, num = info
|
|
39
|
-
max_diff_index = get_header_index(
|
|
49
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
50
|
+
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
40
51
|
if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
|
|
41
52
|
return
|
|
42
53
|
in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
|
|
43
54
|
out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
|
|
44
55
|
if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
|
|
45
|
-
color_columns.yellow
|
|
56
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
57
|
+
"maximum absolute error of both input and output exceed 1, "
|
|
58
|
+
"with the output larger by an order of magnitude")
|
|
46
59
|
|
|
47
60
|
|
|
48
61
|
class CheckOneThousandErrorRatio(HighlightCheck):
|
|
49
62
|
"""检查千分误差比率"""
|
|
50
|
-
def apply(self, info, color_columns,
|
|
63
|
+
def apply(self, info, color_columns, dump_mode):
|
|
51
64
|
api_in, api_out, num = info
|
|
52
|
-
one_thousand_index = get_header_index(
|
|
65
|
+
one_thousand_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
|
|
53
66
|
if (not isinstance(api_in[one_thousand_index], (float, int)) or
|
|
54
67
|
not isinstance(api_out[one_thousand_index], (float, int))):
|
|
55
68
|
return
|
|
56
69
|
if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
57
70
|
api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
|
|
58
|
-
color_columns.red
|
|
71
|
+
add_highlight_row_info(color_columns.red, num,
|
|
72
|
+
"The input's one thousandth err ratio exceeds 0.9, while the output's is below 0.6")
|
|
59
73
|
elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
|
|
60
|
-
color_columns.yellow
|
|
74
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
75
|
+
"The output's one thousandth err ratio decreases by more than 0.1 "
|
|
76
|
+
"compared to the input's")
|
|
61
77
|
|
|
62
78
|
|
|
63
79
|
class CheckCosineSimilarity(HighlightCheck):
|
|
64
80
|
"""检查余弦相似度"""
|
|
65
|
-
def apply(self, info, color_columns,
|
|
81
|
+
def apply(self, info, color_columns, dump_mode):
|
|
66
82
|
api_in, api_out, num = info
|
|
67
|
-
cosine_index = get_header_index(
|
|
83
|
+
cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
|
|
68
84
|
if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
|
|
69
85
|
return
|
|
70
86
|
if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
|
|
71
|
-
color_columns.yellow
|
|
87
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
88
|
+
"The output's cosine decreases by more than 0.1 compared to the input's")
|
|
72
89
|
|
|
73
90
|
|
|
74
91
|
class CheckMaxRelativeDiff(HighlightCheck):
|
|
75
92
|
"""检查最大相对差异"""
|
|
76
|
-
def apply(self, info, color_columns,
|
|
93
|
+
def apply(self, info, color_columns, dump_mode):
|
|
77
94
|
api_in, api_out, num = info
|
|
78
|
-
max_diff_index = get_header_index(
|
|
79
|
-
bench_max_index = get_header_index(
|
|
95
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
|
|
96
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
80
97
|
input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
|
|
81
98
|
output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
|
|
82
99
|
if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
|
|
83
100
|
(float, int)):
|
|
84
101
|
return
|
|
85
102
|
if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
|
|
86
|
-
color_columns.red.
|
|
103
|
+
add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
|
|
87
104
|
elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
88
105
|
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
89
|
-
color_columns.yellow
|
|
106
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
107
|
+
"The output's maximum relative error exceeds 0.1, while the input's is below 0.01")
|
|
90
108
|
|
|
91
109
|
|
|
92
110
|
class CheckOverflow(HighlightCheck):
|
|
93
111
|
"""检查是否存在溢出"""
|
|
94
|
-
def apply(self, info, color_columns,
|
|
112
|
+
def apply(self, info, color_columns, dump_mode):
|
|
95
113
|
line, num = info
|
|
96
|
-
npu_max_index = get_header_index(
|
|
97
|
-
npu_min_index = get_header_index(
|
|
98
|
-
max_diff_index = get_header_index(
|
|
114
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
115
|
+
npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
|
|
116
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
117
|
+
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
99
118
|
if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
|
|
100
119
|
line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
|
|
101
|
-
color_columns.red
|
|
120
|
+
add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
|
|
102
121
|
return
|
|
103
122
|
# check if Max_Diff > 1e+10
|
|
104
|
-
if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
|
|
105
|
-
color_columns.red
|
|
123
|
+
if isinstance(line[max_diff_index], (float, int)) and abs(line[max_diff_index]) > CompareConst.MAX_DIFF_RED:
|
|
124
|
+
add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
|
|
106
125
|
|
|
107
126
|
|
|
108
127
|
class HighlightRules:
|
|
@@ -124,13 +143,14 @@ class HighlightRules:
|
|
|
124
143
|
}
|
|
125
144
|
|
|
126
145
|
|
|
127
|
-
def find_error_rows(result, last_len, n_num_input, highlight_dict,
|
|
146
|
+
def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
|
|
128
147
|
"""找到单个API中需要高亮的行"""
|
|
129
|
-
if
|
|
148
|
+
if dump_mode == Const.MD5:
|
|
130
149
|
return
|
|
131
|
-
npu_max_index = get_header_index(
|
|
132
|
-
bench_max_index = get_header_index(
|
|
133
|
-
max_diff_index = get_header_index(
|
|
150
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
151
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
152
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
153
|
+
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
134
154
|
|
|
135
155
|
red_lines, yellow_lines = [], []
|
|
136
156
|
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
@@ -143,7 +163,7 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
|
|
|
143
163
|
num = last_len + i
|
|
144
164
|
line_info = LineInfo(line_data=line, num_pointer=num)
|
|
145
165
|
for rule in HighlightRules.basic_rules.values():
|
|
146
|
-
rule.apply(line_info, color_columns,
|
|
166
|
+
rule.apply(line_info, color_columns, dump_mode)
|
|
147
167
|
|
|
148
168
|
# 对API的输出与输入比较,进行误差判断
|
|
149
169
|
for n, api_out in enumerate(result[n_num_input:len(result)]):
|
|
@@ -161,36 +181,42 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
|
|
|
161
181
|
continue
|
|
162
182
|
|
|
163
183
|
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
|
|
164
|
-
if
|
|
184
|
+
if dump_mode == Const.SUMMARY:
|
|
165
185
|
for rule in HighlightRules.summary_compare_rules.values():
|
|
166
|
-
rule.apply(api_info, color_columns,
|
|
186
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
167
187
|
else:
|
|
168
188
|
for rule in HighlightRules.compare_rules.values():
|
|
169
|
-
rule.apply(api_info, color_columns,
|
|
189
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
170
190
|
|
|
171
|
-
|
|
172
|
-
|
|
191
|
+
red_lines_num_set = {x[0] for x in red_lines}
|
|
192
|
+
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
193
|
+
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
194
|
+
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
195
|
+
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
196
|
+
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
173
197
|
|
|
174
198
|
|
|
175
199
|
def get_name_and_state(name):
|
|
176
200
|
"""Get api/module name and state"""
|
|
177
|
-
if
|
|
178
|
-
api_name = name.split(
|
|
179
|
-
state =
|
|
201
|
+
if Const.INPUT in name:
|
|
202
|
+
api_name = name.split(Const.INPUT)[0]
|
|
203
|
+
state = Const.INPUT
|
|
180
204
|
else:
|
|
181
|
-
api_name = name.split(
|
|
182
|
-
state =
|
|
205
|
+
api_name = name.split(Const.OUTPUT)[0]
|
|
206
|
+
state = Const.OUTPUT
|
|
183
207
|
return api_name, state
|
|
184
208
|
|
|
185
209
|
|
|
186
|
-
def find_compare_result_error_rows(result_df, highlight_dict,
|
|
210
|
+
def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
|
|
187
211
|
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
188
212
|
result = result_df.values
|
|
189
213
|
start, input_num, output_num, end = 0, 0, 0, len(result_df)
|
|
190
214
|
last_api_name, last_state = None, None
|
|
191
215
|
num, last_len = 0, 0
|
|
216
|
+
progress_bar = tqdm(total=len(result), desc="API/Module Analyse Progress", unit="item", ncols=100)
|
|
192
217
|
for res_i in result:
|
|
193
|
-
|
|
218
|
+
api_full_name = safe_get_value(res_i, 0, "res_i")
|
|
219
|
+
api_name, state = get_name_and_state(api_full_name)
|
|
194
220
|
if last_api_name:
|
|
195
221
|
if api_name == last_api_name:
|
|
196
222
|
if state == last_state:
|
|
@@ -201,29 +227,33 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
|
|
|
201
227
|
else:
|
|
202
228
|
output_num = num
|
|
203
229
|
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
204
|
-
|
|
230
|
+
dump_mode)
|
|
205
231
|
num, last_api_name, last_state = 1, api_name, state
|
|
206
232
|
start += input_num + output_num
|
|
207
233
|
input_num, output_num = 1, 0
|
|
208
234
|
else:
|
|
209
235
|
num, last_api_name, last_state = 1, api_name, state
|
|
236
|
+
progress_bar.update(1)
|
|
237
|
+
progress_bar.close()
|
|
210
238
|
if state:
|
|
211
|
-
if state ==
|
|
239
|
+
if state == Const.INPUT:
|
|
212
240
|
input_num = num
|
|
213
241
|
else:
|
|
214
242
|
output_num = num
|
|
215
243
|
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
216
|
-
|
|
244
|
+
dump_mode)
|
|
217
245
|
|
|
218
246
|
|
|
219
247
|
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
220
248
|
"""Write and highlight results in Excel"""
|
|
221
|
-
|
|
249
|
+
|
|
250
|
+
update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
222
251
|
|
|
223
252
|
wb = openpyxl.Workbook()
|
|
224
253
|
ws = wb.active
|
|
225
254
|
|
|
226
255
|
# write header
|
|
256
|
+
logger.info('Initializing Excel file.')
|
|
227
257
|
for j, col_name in enumerate(result_df.columns, start=1):
|
|
228
258
|
if not csv_value_is_valid(col_name):
|
|
229
259
|
raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
|
|
@@ -231,20 +261,59 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
|
231
261
|
|
|
232
262
|
for i, row in enumerate(result_df.iterrows(), start=2):
|
|
233
263
|
for j, value in enumerate(row[1], start=1):
|
|
234
|
-
if not isinstance(value, (float, int)):
|
|
264
|
+
if not isinstance(value, (float, int)) or isinstance(value, bool):
|
|
235
265
|
value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
|
|
236
266
|
if not csv_value_is_valid(value):
|
|
237
|
-
raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx:
|
|
267
|
+
raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: "
|
|
268
|
+
f"{file_path}.")
|
|
238
269
|
ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
|
|
270
|
+
|
|
271
|
+
# 对可疑数据标色
|
|
272
|
+
logger.info('Coloring Excel in progress.')
|
|
273
|
+
col_len = len(result_df.columns)
|
|
274
|
+
red_fill = PatternFill(
|
|
275
|
+
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
276
|
+
)
|
|
277
|
+
yellow_fill = PatternFill(
|
|
278
|
+
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
279
|
+
)
|
|
280
|
+
for i in highlight_dict.get("red_rows", []):
|
|
281
|
+
for j in range(1, col_len + 1):
|
|
282
|
+
ws.cell(row=i + 2, column=j).fill = red_fill
|
|
283
|
+
for i in highlight_dict.get("yellow_rows", []):
|
|
284
|
+
for j in range(1, col_len + 1):
|
|
285
|
+
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
286
|
+
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
287
|
+
save_workbook(wb, file_path)
|
|
239
288
|
|
|
240
|
-
if (i - 2) in highlight_dict['red_rows']:
|
|
241
|
-
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
|
|
242
|
-
end_color=CompareConst.RED, fill_type="solid")
|
|
243
|
-
elif (i - 2) in highlight_dict['yellow_rows']:
|
|
244
|
-
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
|
|
245
|
-
end_color=CompareConst.YELLOW, fill_type="solid")
|
|
246
289
|
|
|
247
|
-
|
|
290
|
+
def update_highlight_err_msg(result_df, highlight_dict):
|
|
291
|
+
if result_df.shape[1] <= 1:
|
|
292
|
+
return
|
|
293
|
+
|
|
294
|
+
if CompareConst.NPU_MD5 in result_df.columns:
|
|
295
|
+
return
|
|
296
|
+
|
|
297
|
+
err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
|
|
298
|
+
red_lines_num_set = highlight_dict.get('red_rows')
|
|
299
|
+
|
|
300
|
+
for color in ['red', 'yellow']:
|
|
301
|
+
line_key = f'{color}_lines'
|
|
302
|
+
lines = highlight_dict.get(line_key, [])
|
|
303
|
+
for line_index, messages in lines:
|
|
304
|
+
if color == 'yellow' and line_index in red_lines_num_set:
|
|
305
|
+
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
306
|
+
|
|
307
|
+
for msg in messages:
|
|
308
|
+
if err_msg[line_index] == '':
|
|
309
|
+
err_msg[line_index] = msg
|
|
310
|
+
else:
|
|
311
|
+
err_msg[line_index] += '\n' + msg
|
|
312
|
+
|
|
313
|
+
if color == 'red':
|
|
314
|
+
red_lines_num_set.add(line_index)
|
|
315
|
+
|
|
316
|
+
result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
248
317
|
|
|
249
318
|
|
|
250
319
|
def csv_value_is_valid(value: str) -> bool:
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.compare.layer_mapping.layer_mapping import (
|
|
17
|
+
generate_data_mapping_by_layer_mapping,
|
|
18
|
+
generate_api_mapping_by_layer_mapping,
|
|
19
|
+
)
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import ClassVar, Dict, List, Optional, Tuple
|
|
21
|
+
|
|
22
|
+
import yaml
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.core.common.file_utils import save_yaml
|
|
25
|
+
from msprobe.core.common.log import logger
|
|
26
|
+
from msprobe.core.common.utils import CompareException, add_time_with_yaml
|
|
27
|
+
from msprobe.core.compare.layer_mapping.postprocess_pass import postprocess_pass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class DumpDataItem:
    """Scope information derived from a single dump data entry.

    Instances are created with only ``framework`` and then populated by
    :meth:`set`, which parses the dump data name, the construct info and the
    call-stack lines, and finally assembles ``full_scope``.
    """

    framework: str
    data_name: Optional[str] = None
    api_type: Optional[str] = None
    api_name: Optional[str] = None
    type_name: Optional[str] = None
    # Const.SEP-joined: layer scope, optionally user/frame stack scopes, then the api name.
    full_scope: str = ""
    layer_scope: str = ""
    stack_scope: str = ""
    frame_stack_scope: str = ""
    user_stack_scope: str = ""
    construct_scope: str = ""
    # Both are copied verbatim from the Const.SEP-split construct info, so they hold strings.
    scope_direction: Optional[str] = None
    scope_id: Optional[str] = None

    # Class variables use ClassVar so that dataclass does not treat them as fields.
    # Maps each framework to the api type used for layer entries (Const.CELL / Const.MODULE).
    framework2layername: ClassVar[Dict[str, str]] = {
        Const.MS_FRAMEWORK: Const.CELL, Const.PT_FRAMEWORK: Const.MODULE}
    # (start regex, end regex) pairs handed to find_regard_scope to cut the stack window.
    framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
        Const.MS_FRAMEWORK: ("Template", "construct"),
        Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
    }

    @staticmethod
    def check_stack_valid(stack_info):
        """Ensure ``stack_info`` is ``None`` or a list whose elements are all str.

        Raises:
            CompareException: if ``stack_info`` is not a list or contains a non-str element.
        """
        if stack_info is not None:
            if not isinstance(stack_info, list):
                logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            for stack in stack_info:
                if not isinstance(stack, str):
                    logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
                    raise CompareException(CompareException.INVALID_DATA_ERROR)

    def set(self, data_name: str, construct_info: Optional[str], stack_info: Optional[List[str]]) -> None:
        """Populate every scope field for one dump entry.

        Args:
            data_name: dump data name, a Const.SEP-separated string.
            construct_info: Const.SEP-separated construct (parent) info, or a falsy value when absent.
            stack_info: list of call-stack lines, or a falsy value when absent.

        Raises:
            CompareException: if any of the inputs is malformed.
        """
        self.set_name(data_name)
        self.set_layer_scope(construct_info)
        self.set_stack_scope(stack_info)
        self.set_full_scope()

    def set_name(self, data_name: str) -> None:
        """Store ``data_name`` and derive ``api_type``, ``type_name`` and ``api_name`` from it.

        For layer entries (api type equal to the framework's Cell/Module name) the
        api name is the field at ``Const.LAYER_NAME_INDEX``; otherwise it is the type name.

        Raises:
            CompareException: if the name has fewer fields than ``abs(Const.LAYER_NAME_INDEX)``.
        """
        self.data_name = data_name
        data_name_list = data_name.split(Const.SEP)
        if not data_name_list or len(data_name_list) < abs(Const.LAYER_NAME_INDEX):
            logger.error(
                f"The dump data does not comply with the format specification and "
                f"must contain no less than four fields. "
                f"The current data is {data_name}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)

        self.api_type = data_name_list[Const.API_TYPE_INDEX]
        self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
        if self.api_type == self.framework2layername.get(self.framework):
            self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
        else:
            self.api_name = self.type_name

    def set_layer_scope(self, construct_info: Optional[str]) -> None:
        """Derive ``layer_scope`` (and, when available, ``scope_id``/``scope_direction``).

        Layer entries build their scope from their own data name; other entries use
        ``construct_info``. Without either, the framework's layer name is used.

        Raises:
            CompareException: if ``construct_info`` has fewer fields than ``abs(Const.LAYER_NAME_INDEX)``.
        """
        self.construct_scope = construct_info
        if self.api_type == self.framework2layername.get(self.framework):
            # Remove the api (layer) name: drop the fields between LAYER_NAME_INDEX and TYPE_NAME_INDEX.
            data_list = self.data_name.split(Const.SEP)
            data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
        elif construct_info:
            data_list = construct_info.split(Const.SEP)
        else:
            data_list = []

        if data_list:
            self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
        else:
            self.layer_scope = self.framework2layername.get(self.framework)
        if construct_info:
            construct_list = construct_info.split(Const.SEP)
            if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                logger.error(
                    f"The construct data does not comply with the format specification and "
                    f"must contain no less than four fields. "
                    f"The current data is {construct_info}"
                )
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            self.scope_id = construct_list[Const.SCOPE_ID_INDEX]
            self.scope_direction = construct_list[Const.SCOPE_DIRECTION_INDEX]

    def set_stack_scope(self, stack_info: Optional[List[str]]) -> None:
        """Derive ``frame_stack_scope`` and ``user_stack_scope`` from the call stack.

        Layer entries, skip-listed api types and entries without a stack are left untouched.

        Raises:
            CompareException: if the framework has no stack markers or ``stack_info`` is malformed.
        """
        # Cell/Module has no stack info
        if self.api_type == self.framework2layername.get(self.framework):
            return

        if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
            return

        stack_signs = self.framework2stack_sign.get(self.framework)
        if stack_signs is None:
            # Previously this fell through to tuple-unpacking None and raised an opaque TypeError.
            logger.error(f"Unsupported framework {self.framework}, cannot locate the stack scope.")
            raise CompareException(CompareException.INVALID_DATA_ERROR)
        start_sign, end_sign = stack_signs
        self.check_stack_valid(stack_info)
        start_pos, end_pos = find_regard_scope(stack_info, start_sign, end_sign)
        # Keep only the lines strictly between the start marker and the end marker.
        regard_scope = stack_info[start_pos + 1:end_pos]
        frame_func_stack_list, user_func_stack_list = find_stack_func_list(regard_scope)
        self.frame_stack_scope = Const.SEP.join(frame_func_stack_list)
        self.user_stack_scope = Const.SEP.join(user_func_stack_list)

    def set_full_scope(self, use_user_func_scope=False, use_frame_func_scope=True) -> None:
        """Join layer scope, the selected non-empty stack scopes and the api name into ``full_scope``."""
        scope_list = [self.layer_scope]
        if use_user_func_scope and self.user_stack_scope:
            scope_list.append(self.user_stack_scope)
        if use_frame_func_scope and self.frame_stack_scope:
            scope_list.append(self.frame_stack_scope)
        scope_list.append(self.api_name)
        self.full_scope = Const.SEP.join(scope_list)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def find_regard_scope(lines, start_sign, end_sign):
    """Return the ``(start, end)`` indices bounding the region of interest in ``lines``.

    ``start`` is the index of the most recent line matching the ``start_sign``
    regex before the first later line matching the ``end_sign`` regex; ``end`` is
    the index of that end line. A line that matches ``start_sign`` is never
    treated as an end line. When no start line exists, ``start`` is -1 and end
    markers are ignored. When no end line follows, ``end`` is ``len(lines)``.
    """
    total = len(lines)
    begin = -1
    for position, text in enumerate(lines):
        if re.search(start_sign, text):
            begin = position
            continue
        if begin < 0:
            continue
        if re.search(end_sign, text):
            return begin, position
    return begin, total
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def find_stack_func_list(lines, record_user=True):
    """Split stack lines into the framework entrance frame and the user frames.

    Lines are scanned in order. While every line seen so far has a file field
    containing an entry of ``Const.FRAME_FILE_LIST``, the current line is kept as
    the framework entrance candidate, so the last line of that leading run wins.
    From the first line outside that run onward, every line is collected as a
    user frame when ``record_user`` is true. This includes later framework lines.

    Args:
        lines: stack lines, each a ``Const.COMMA``-separated string.
        record_user: whether to collect the user frames.

    Returns:
        Tuple ``(frame_func, user_func)`` of function-name lists produced by
        ``get_stack_in_lines`` for the entrance frame and the user frames.
    """
    user_stack = []
    frame_stack = None
    in_frame_prefix = True
    for line in lines:
        file_ele = line.split(Const.COMMA)[Const.STACK_FILE_INDEX]
        # Still inside the leading run of framework lines: remember the latest one.
        if in_frame_prefix and any(frame_file in file_ele for frame_file in Const.FRAME_FILE_LIST):
            frame_stack = line
            continue
        in_frame_prefix = False
        if record_user:
            user_stack.append(line)

    # get_stack_in_lines skips falsy entries, so a missing entrance yields [].
    frame_func = get_stack_in_lines([frame_stack])
    user_func = get_stack_in_lines(user_stack)
    return frame_func, user_func
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_stack_in_lines(simplified: List[str]):
    """Extract function names from stack lines, in reverse order of the input.

    Empty/``None`` entries are ignored. Lines whose file field contains an entry
    of ``Const.FILE_SKIP_LIST``, or whose function field contains an entry of
    ``Const.FUNC_SKIP_LIST``, are dropped. For every remaining line, the
    whitespace-split token at ``Const.STACK_FUNC_ELE_INDEX`` of the function
    field is collected.
    """
    collected = []
    for entry in simplified or ():
        if not entry:
            continue
        fields = entry.split(Const.COMMA)
        if any(skip in fields[Const.STACK_FILE_INDEX] for skip in Const.FILE_SKIP_LIST):
            continue
        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(skip in func_field for skip in Const.FUNC_SKIP_LIST):
            continue
        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])
    collected.reverse()
    return collected
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def dumpdata_representer(dumper, data):
    """YAML representer that serializes an object's attributes as a mapping.

    The ``data_name`` attribute is left out of the output. The object itself is
    not modified.

    Args:
        dumper: the yaml dumper invoking this representer.
        data: the object being represented (a DumpDataItem when registered by
            get_dump_data_items).

    Returns:
        The node produced by ``dumper.represent_dict``.
    """
    # A filtered shallow copy is enough to leave ``data`` untouched; the previous
    # deepcopy recursively copied every item's attributes just to drop one key.
    fields = {key: value for key, value in vars(data).items() if key != "data_name"}
    return dumper.represent_dict(fields)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_dump_data_items(dump, stack, construct, framework, output_path=None):
    """Build one DumpDataItem per entry of ``dump["data"]``.

    Each item is populated from its construct and stack info, after which all
    items go through ``postprocess_pass``. When ``output_path`` is given, the
    items are also saved as a timestamped ``{framework}_data`` yaml file in that
    directory.

    Returns:
        The items in dump order, or an empty list when ``stack`` or
        ``construct`` is empty.
    """
    if not (stack and construct):
        return []

    items_by_name = {}
    ordered_items = []
    for name in dump.get("data", {}):
        item = DumpDataItem(framework)
        item.set(name, construct.get(name, None), stack.get(name, None))
        items_by_name[name] = item
        ordered_items.append(item)

    postprocess_pass(ordered_items, items_by_name)

    if output_path:
        # Serialize items without their data_name attribute.
        yaml.add_representer(DumpDataItem, dumpdata_representer)
        yaml_name = add_time_with_yaml(f"{framework}_data")
        target = os.path.join(os.path.realpath(output_path), yaml_name)
        save_yaml(target, items_by_name)
    return ordered_items
|