mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -1,12 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import copy
|
|
2
|
-
import csv
|
|
3
17
|
import glob
|
|
4
18
|
import os
|
|
19
|
+
import re
|
|
5
20
|
|
|
6
21
|
import numpy as np
|
|
7
22
|
import pandas as pd
|
|
8
|
-
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
9
|
-
from msprobe.core.common.file_utils import
|
|
23
|
+
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
24
|
+
from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
|
|
10
25
|
from msprobe.core.common.log import logger
|
|
11
26
|
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
12
27
|
from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
|
|
@@ -14,7 +29,7 @@ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_che
|
|
|
14
29
|
from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
|
|
15
30
|
|
|
16
31
|
|
|
17
|
-
class row_data:
|
|
32
|
+
class RowData:
|
|
18
33
|
def __init__(self, mode):
|
|
19
34
|
self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
|
|
20
35
|
self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
|
|
@@ -28,17 +43,34 @@ class row_data:
|
|
|
28
43
|
return self.data
|
|
29
44
|
|
|
30
45
|
|
|
46
|
+
def get_name_dict(name: str) -> dict:
|
|
47
|
+
compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
|
|
48
|
+
r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
|
|
49
|
+
match = compare_pattern.match(name)
|
|
50
|
+
if match:
|
|
51
|
+
return {'op_type': match.group(1),
|
|
52
|
+
'op_name': match.group(2),
|
|
53
|
+
'task_id': match.group(3),
|
|
54
|
+
'stream_id': match.group(4),
|
|
55
|
+
'timestamp': match.group(5).split(Const.SEP)[0],
|
|
56
|
+
'input_output_index': match.group(6),
|
|
57
|
+
'slot': match.group(7),
|
|
58
|
+
'format': match.group(8)}
|
|
59
|
+
return {}
|
|
60
|
+
|
|
61
|
+
|
|
31
62
|
def npy_data_read(data_path, npy_file_list, mapping_dict):
|
|
32
63
|
data_list = []
|
|
64
|
+
compare_key_elements = ['op_name', 'task_id', 'input_output_index', 'slot']
|
|
33
65
|
for data in npy_file_list:
|
|
34
66
|
if data in mapping_dict:
|
|
35
|
-
|
|
67
|
+
name_dict = get_name_dict(mapping_dict[data])
|
|
36
68
|
else:
|
|
37
|
-
|
|
38
|
-
if
|
|
69
|
+
name_dict = get_name_dict(data)
|
|
70
|
+
if not name_dict:
|
|
39
71
|
continue
|
|
40
|
-
compare_key =
|
|
41
|
-
timestamp = convert_to_int(
|
|
72
|
+
compare_key = Const.SEP.join([name_dict.get(element) for element in compare_key_elements])
|
|
73
|
+
timestamp = convert_to_int(name_dict.get('timestamp'))
|
|
42
74
|
|
|
43
75
|
data_list.append([os.path.join(data_path, data), compare_key, timestamp])
|
|
44
76
|
return data_list
|
|
@@ -48,18 +80,17 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
48
80
|
data_list = []
|
|
49
81
|
statistic_data_list = []
|
|
50
82
|
header_index = {
|
|
51
|
-
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
52
|
-
'Min Value': None,'Avg Value': None, 'L2Norm Value': None
|
|
83
|
+
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
84
|
+
'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
|
|
53
85
|
}
|
|
54
86
|
for statistic_file in statistic_file_list:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
for
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
statistic_data_list.extend([row for row in csv_reader])
|
|
87
|
+
content = read_csv(statistic_file, as_pd=False)
|
|
88
|
+
header = content[0]
|
|
89
|
+
for key in header_index.keys():
|
|
90
|
+
for index, value in enumerate(header):
|
|
91
|
+
if key == value:
|
|
92
|
+
header_index[key] = index
|
|
93
|
+
statistic_data_list.extend(content[1:])
|
|
63
94
|
|
|
64
95
|
for key in header_index.keys():
|
|
65
96
|
if header_index[key] is None:
|
|
@@ -97,11 +128,9 @@ def generate_data_name(data_path):
|
|
|
97
128
|
mapping_dict = {}
|
|
98
129
|
if mapping_exist:
|
|
99
130
|
for mapping_file in mapping_file_list:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
for row in csv_reader:
|
|
104
|
-
mapping_dict[row[0]] = row[1]
|
|
131
|
+
content = read_csv(mapping_file, False)
|
|
132
|
+
for row in content[1:]:
|
|
133
|
+
mapping_dict[row[0]] = row[1]
|
|
105
134
|
|
|
106
135
|
if npy_exist:
|
|
107
136
|
data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
|
|
@@ -136,7 +165,7 @@ class GraphMSComparator:
|
|
|
136
165
|
def compare_ops(compare_result_db, mode):
|
|
137
166
|
|
|
138
167
|
def npy_mode_compute(row):
|
|
139
|
-
result_dict =
|
|
168
|
+
result_dict = RowData(GraphMode.NPY_MODE)()
|
|
140
169
|
|
|
141
170
|
def process_npy_file(file_path, name_prefix, result):
|
|
142
171
|
if os.path.exists(file_path):
|
|
@@ -171,7 +200,7 @@ class GraphMSComparator:
|
|
|
171
200
|
return pd.Series(result_dict)
|
|
172
201
|
|
|
173
202
|
def statistic_mode_compute(row):
|
|
174
|
-
result_dict =
|
|
203
|
+
result_dict = RowData('STATISTIC')()
|
|
175
204
|
|
|
176
205
|
def update_result_dict(result, rows, prefix):
|
|
177
206
|
result[f'{prefix} Name'] = rows[f'{prefix} Name']
|
|
@@ -198,24 +227,30 @@ class GraphMSComparator:
|
|
|
198
227
|
result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
|
|
199
228
|
result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
|
|
200
229
|
CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
|
|
201
|
-
|
|
230
|
+
if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
|
|
231
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
|
|
232
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
|
|
202
233
|
result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
|
|
203
234
|
CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
|
|
204
|
-
|
|
235
|
+
if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
|
|
236
|
+
result_dict[CompareConst.MIN_RELATIVE_ERR] = \
|
|
237
|
+
str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
|
|
205
238
|
result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
|
|
206
239
|
CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
|
|
207
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
208
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
240
|
+
if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
|
|
241
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
|
|
242
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
|
|
209
243
|
result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
|
|
210
244
|
CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
|
|
211
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
212
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
245
|
+
if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
|
|
246
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
|
|
247
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
|
|
213
248
|
magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
|
|
214
249
|
max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
|
|
215
|
-
if
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
250
|
+
if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
|
|
251
|
+
magnitude_diff = 0
|
|
252
|
+
result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
|
|
253
|
+
magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
|
|
219
254
|
|
|
220
255
|
return pd.Series(result_dict)
|
|
221
256
|
|
|
@@ -238,24 +273,23 @@ class GraphMSComparator:
|
|
|
238
273
|
is_empty = True
|
|
239
274
|
if is_empty or not mode:
|
|
240
275
|
continue
|
|
241
|
-
compare_result_df = self.
|
|
276
|
+
compare_result_df = self.do_multi_process(compare_result_df, mode)
|
|
242
277
|
compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
|
|
243
278
|
compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
|
|
244
|
-
check_path_before_create(compare_result_path)
|
|
245
279
|
self.to_excel(compare_result_df, compare_result_path)
|
|
246
280
|
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
247
|
-
|
|
281
|
+
|
|
248
282
|
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
|
|
249
283
|
size = len(compare_result_df)
|
|
250
284
|
# sheet size cannot be larger than 1048576
|
|
251
285
|
if size < CompareConst.MAX_EXCEL_LENGTH:
|
|
252
|
-
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if
|
|
253
|
-
|
|
254
|
-
|
|
286
|
+
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
|
|
287
|
+
need_slice else compare_result_path
|
|
288
|
+
save_excel(compare_result_path, compare_result_df)
|
|
255
289
|
return slice_num + 1
|
|
256
290
|
else:
|
|
257
|
-
slice_num = self.to_excel(compare_result_df.iloc[0: size//2], compare_result_path, slice_num, True)
|
|
258
|
-
return self.to_excel(compare_result_df.iloc[size//2:], compare_result_path, slice_num, True)
|
|
291
|
+
slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
|
|
292
|
+
return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
|
|
259
293
|
|
|
260
294
|
def compare_process(self, rank_id, step_id):
|
|
261
295
|
# generate data_path
|
|
@@ -303,8 +337,8 @@ class GraphMSComparator:
|
|
|
303
337
|
npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
|
|
304
338
|
|
|
305
339
|
bench_float_type = [
|
|
306
|
-
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
307
|
-
CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
|
|
340
|
+
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
341
|
+
CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
|
|
308
342
|
]
|
|
309
343
|
bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
|
|
310
344
|
|
|
@@ -355,7 +389,7 @@ class GraphMSComparator:
|
|
|
355
389
|
rank_step_path_dict[rank_step_key] = [dir_path]
|
|
356
390
|
return dict(sorted(rank_step_path_dict.items()))
|
|
357
391
|
|
|
358
|
-
def
|
|
392
|
+
def do_multi_process(self, result_df, mode):
|
|
359
393
|
try:
|
|
360
394
|
result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
|
|
361
395
|
except ValueError as e:
|
|
@@ -33,7 +33,7 @@ class DebuggerConfig:
|
|
|
33
33
|
self.level_ori = common_config.level
|
|
34
34
|
self.list = [] if not task_config.list else task_config.list
|
|
35
35
|
self.scope = [] if not task_config.scope else task_config.scope
|
|
36
|
-
self.data_mode = [] if not task_config.data_mode else task_config.data_mode
|
|
36
|
+
self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
|
|
37
37
|
self.file_format = task_config.file_format
|
|
38
38
|
self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
|
|
39
39
|
self.check_mode = task_config.check_mode
|
|
@@ -52,6 +52,9 @@ class DebuggerConfig:
|
|
|
52
52
|
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
53
53
|
raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
54
54
|
f"but got {self.pert_type}.")
|
|
55
|
+
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
56
|
+
raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
|
|
57
|
+
f"but got {self.handler_type}.")
|
|
55
58
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
56
59
|
|
|
57
60
|
def check(self):
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
from collections import defaultdict
|
|
17
18
|
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore._c_expression import MSContext
|
|
20
21
|
|
|
21
22
|
from msprobe.core.common.const import Const, MsgConst
|
|
23
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
22
24
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
23
25
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
26
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
24
27
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
25
28
|
from msprobe.mindspore.ms_config import parse_json_config
|
|
26
29
|
from msprobe.mindspore.runtime import Runtime
|
|
@@ -128,6 +131,9 @@ class PrecisionDebugger:
|
|
|
128
131
|
return
|
|
129
132
|
if instance.service:
|
|
130
133
|
instance.service.step()
|
|
134
|
+
HOOKCell.cell_count = defaultdict(int)
|
|
135
|
+
CellProcessor.reset_cell_stats()
|
|
136
|
+
|
|
131
137
|
Runtime.step_count += 1
|
|
132
138
|
|
|
133
139
|
@classmethod
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -40,6 +40,8 @@ class DumpToolFactory:
|
|
|
40
40
|
|
|
41
41
|
@staticmethod
|
|
42
42
|
def create(config: DebuggerConfig):
|
|
43
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
44
|
+
raise Exception("data_mode must be one of all, input, output.")
|
|
43
45
|
tool = DumpToolFactory.tools.get(config.level)
|
|
44
46
|
if not tool:
|
|
45
47
|
raise Exception("Valid level is needed.")
|
|
@@ -24,6 +24,12 @@ from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTenso
|
|
|
24
24
|
from msprobe.core.common.utils import Const
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
def stub_method(method):
|
|
28
|
+
def wrapped_method(*args, **kwargs):
|
|
29
|
+
return method(*args, **kwargs)
|
|
30
|
+
return wrapped_method
|
|
31
|
+
|
|
32
|
+
|
|
27
33
|
class ApiRegistry:
|
|
28
34
|
def __init__(self):
|
|
29
35
|
self.tensor_ori_attr = {}
|
|
@@ -50,9 +56,13 @@ class ApiRegistry:
|
|
|
50
56
|
if Const.SEP in api:
|
|
51
57
|
sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
|
|
52
58
|
sub_module = getattr(ori_api_group, sub_module_name)
|
|
53
|
-
|
|
59
|
+
ori_api_func = getattr(sub_module, sub_op)
|
|
54
60
|
else:
|
|
55
|
-
|
|
61
|
+
ori_api_func = getattr(ori_api_group, api)
|
|
62
|
+
if ori_api_group == StubTensor:
|
|
63
|
+
api_ori_attr[api] = stub_method(ori_api_func)
|
|
64
|
+
continue
|
|
65
|
+
api_ori_attr[api] = ori_api_func
|
|
56
66
|
|
|
57
67
|
@staticmethod
|
|
58
68
|
def set_api_attr(api_group, attr_dict):
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,18 +12,16 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
|
|
18
|
-
import mindspore as ms
|
|
19
|
-
from mindspore.common.tensor import Tensor
|
|
20
18
|
from mindspore import ops
|
|
19
|
+
from mindspore.common.tensor import Tensor
|
|
21
20
|
|
|
22
|
-
from msprobe.mindspore.common.log import logger
|
|
23
21
|
from msprobe.core.common.utils import Const, DumpException
|
|
24
|
-
from msprobe.core.data_dump.data_processor.base import
|
|
25
|
-
|
|
22
|
+
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
|
|
23
|
+
ModuleForwardInputsOutputs)
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
class PrimitiveHookService:
|
|
@@ -41,6 +40,7 @@ class PrimitiveHookService:
|
|
|
41
40
|
Returns:
|
|
42
41
|
callable: 包装后的 primitive 函数。
|
|
43
42
|
"""
|
|
43
|
+
|
|
44
44
|
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
|
|
45
45
|
"""
|
|
46
46
|
创建反向 hook 函数,用于捕获梯度。
|
|
@@ -54,26 +54,24 @@ class PrimitiveHookService:
|
|
|
54
54
|
Returns:
|
|
55
55
|
callable: 反向 hook 函数。
|
|
56
56
|
"""
|
|
57
|
-
def backward_hook(grad):
|
|
58
57
|
|
|
59
|
-
|
|
58
|
+
def backward_hook(grad):
|
|
59
|
+
captured_grads.extend(grad)
|
|
60
60
|
backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
|
|
61
61
|
|
|
62
62
|
try:
|
|
63
|
-
if
|
|
63
|
+
if hook_type == Const.INPUT:
|
|
64
64
|
self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
65
65
|
new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
|
|
66
66
|
self.service_instance.data_collector.backward_output_data_collect(
|
|
67
67
|
backward_primitive_name, self, os.getpid(), new_module_input_output
|
|
68
68
|
)
|
|
69
|
-
|
|
70
|
-
elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
|
|
69
|
+
elif hook_type == Const.OUTPUT:
|
|
71
70
|
self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
72
71
|
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
|
|
73
72
|
self.service_instance.data_collector.backward_input_data_collect(
|
|
74
73
|
backward_primitive_name, self, os.getpid(), new_module_input_output
|
|
75
74
|
)
|
|
76
|
-
captured_grads.clear()
|
|
77
75
|
|
|
78
76
|
except Exception as exception:
|
|
79
77
|
logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
|
|
@@ -104,7 +102,7 @@ class PrimitiveHookService:
|
|
|
104
102
|
hooked_inputs.append(arg_hooked)
|
|
105
103
|
else:
|
|
106
104
|
hooked_inputs.append(arg)
|
|
107
|
-
return hooked_inputs
|
|
105
|
+
return tuple(hooked_inputs)
|
|
108
106
|
|
|
109
107
|
def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
|
|
110
108
|
"""
|
|
@@ -178,7 +176,7 @@ class PrimitiveHookService:
|
|
|
178
176
|
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
|
|
179
177
|
try:
|
|
180
178
|
self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
|
|
181
|
-
|
|
179
|
+
os.getpid(), module_input_output)
|
|
182
180
|
except Exception as exception:
|
|
183
181
|
logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
|
|
184
182
|
f"primitive_name: {primitive_name}")
|
|
@@ -203,4 +201,3 @@ class PrimitiveHookService:
|
|
|
203
201
|
self.primitive_counters[primitive_name] = 0
|
|
204
202
|
else:
|
|
205
203
|
self.primitive_counters[primitive_name] += 1
|
|
206
|
-
|
|
@@ -490,6 +490,31 @@ ops:
|
|
|
490
490
|
- scatter_update
|
|
491
491
|
- derivative
|
|
492
492
|
- jet
|
|
493
|
+
- row_stack
|
|
494
|
+
- gather
|
|
495
|
+
- arange
|
|
496
|
+
- cond
|
|
497
|
+
- slice_scatter
|
|
498
|
+
- clip_by_norm
|
|
499
|
+
- eps
|
|
500
|
+
- layer_norm
|
|
501
|
+
- cast
|
|
502
|
+
- numel
|
|
503
|
+
- permute
|
|
504
|
+
- select_scatter
|
|
505
|
+
- group_norm
|
|
506
|
+
- eq
|
|
507
|
+
- embedding
|
|
508
|
+
- ones_like
|
|
509
|
+
- zeros
|
|
510
|
+
- nanmean
|
|
511
|
+
- shape
|
|
512
|
+
- zeros_like
|
|
513
|
+
- ones
|
|
514
|
+
- diagonal_scatter
|
|
515
|
+
- vander
|
|
516
|
+
- is_nonzero
|
|
517
|
+
- rotary_position_embedding
|
|
493
518
|
|
|
494
519
|
tensor:
|
|
495
520
|
- __abs__
|
|
@@ -20,7 +20,7 @@ from mindspore import Tensor
|
|
|
20
20
|
from mindspore._c_expression import PyNativeExecutor_
|
|
21
21
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
22
22
|
|
|
23
|
-
from msprobe.
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
24
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
25
25
|
from msprobe.core.common.const import Const
|
|
26
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
|
|
@@ -33,6 +33,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
33
33
|
index = ori_args.find("<")
|
|
34
34
|
if index != 0 and index != -1:
|
|
35
35
|
result = ori_args[0:index]
|
|
36
|
+
elif name is not None and "<" not in str(name):
|
|
37
|
+
result = str(name)
|
|
36
38
|
else:
|
|
37
39
|
result = "JitFunction"
|
|
38
40
|
if JitDump.need_dump():
|
|
@@ -47,7 +49,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
47
49
|
name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
|
|
48
50
|
Const.BACKWARD
|
|
49
51
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
50
|
-
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat
|
|
52
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
|
|
51
53
|
JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
|
|
52
54
|
|
|
53
55
|
|
|
@@ -59,15 +61,25 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
59
61
|
|
|
60
62
|
def __init__(self, *args, **kwargs):
|
|
61
63
|
super().__init__(*args, **kwargs)
|
|
64
|
+
self.name = None
|
|
65
|
+
if len(args) > 0:
|
|
66
|
+
self.name = args[0].__name__
|
|
62
67
|
self._executor = PyNativeExecutor_.get_instance()
|
|
63
68
|
|
|
64
69
|
def __call__(self, *args, **kwargs):
|
|
65
|
-
|
|
70
|
+
if JitDump.jit_dump_switch:
|
|
71
|
+
api_register.api_set_ori_func()
|
|
66
72
|
out = super().__call__(*args, **kwargs)
|
|
67
73
|
if JitDump.jit_dump_switch and len(args) > 0:
|
|
68
|
-
|
|
74
|
+
if self.name and self.name != "construct":
|
|
75
|
+
dump_jit(self.name, args, out, True)
|
|
76
|
+
else:
|
|
77
|
+
dump_jit(args[0], args, out, True)
|
|
69
78
|
JitDump.jit_enable = True
|
|
70
|
-
|
|
79
|
+
elif len(args) == 0:
|
|
80
|
+
logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
|
|
81
|
+
if JitDump.jit_dump_switch:
|
|
82
|
+
api_register.api_set_hook_func()
|
|
71
83
|
return out
|
|
72
84
|
|
|
73
85
|
@classmethod
|
|
@@ -13,10 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
16
|
import os
|
|
18
17
|
|
|
19
|
-
from msprobe.core.common.file_utils import
|
|
18
|
+
from msprobe.core.common.file_utils import create_directory, save_json
|
|
20
19
|
from msprobe.mindspore.common.log import logger
|
|
21
20
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
22
21
|
|
|
@@ -62,8 +61,7 @@ class KernelGraphDump:
|
|
|
62
61
|
json_path = self.dump_json["common_dump_settings"]["path"]
|
|
63
62
|
create_directory(json_path)
|
|
64
63
|
json_path = os.path.join(json_path, "kernel_graph_dump.json")
|
|
65
|
-
|
|
66
|
-
json.dump(self.dump_json, f)
|
|
64
|
+
save_json(json_path, self.dump_json, indent=4)
|
|
67
65
|
logger.info(json_path + " has been created.")
|
|
68
66
|
os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
|
|
69
67
|
if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
|
|
@@ -13,11 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
16
|
import os
|
|
18
17
|
|
|
19
18
|
from msprobe.core.common.const import Const
|
|
20
|
-
from msprobe.core.common.file_utils import
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory, save_json
|
|
21
20
|
from msprobe.mindspore.common.log import logger
|
|
22
21
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
23
22
|
|
|
@@ -70,8 +69,7 @@ class KernelKbykDump:
|
|
|
70
69
|
json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
|
|
71
70
|
create_directory(json_path)
|
|
72
71
|
json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
|
|
73
|
-
|
|
74
|
-
json.dump(self.dump_json, f)
|
|
72
|
+
save_json(json_path, self.dump_json, indent=4)
|
|
75
73
|
logger.info(json_path + " has been created.")
|
|
76
74
|
|
|
77
75
|
os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
|