mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -28,7 +28,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
|
28
28
|
from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
|
|
29
29
|
check_path_executable, check_path_owner_consistent
|
|
30
30
|
from msprobe.core.common.const import FileCheckConst
|
|
31
|
-
from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type
|
|
31
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
|
|
32
32
|
from msprobe.pytorch.common.log import logger
|
|
33
33
|
|
|
34
34
|
|
|
@@ -81,16 +81,8 @@ class Util:
|
|
|
81
81
|
|
|
82
82
|
@staticmethod
|
|
83
83
|
def get_subfiles_count(directory):
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
check_file_or_directory_path(root, isdir=True)
|
|
87
|
-
file_count += len(files)
|
|
88
|
-
path_depth = root.count(os.sep)
|
|
89
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
90
|
-
yield root, _, files
|
|
91
|
-
else:
|
|
92
|
-
_[:] = []
|
|
93
|
-
return file_count
|
|
84
|
+
files = os_walk_for_files(directory, Const.MAX_TRAVERSAL_DEPTH)
|
|
85
|
+
return len(files)
|
|
94
86
|
|
|
95
87
|
@staticmethod
|
|
96
88
|
def get_sorted_subdirectories_names(directory):
|
|
@@ -146,16 +138,10 @@ class Util:
|
|
|
146
138
|
|
|
147
139
|
@staticmethod
|
|
148
140
|
def dir_contains_only(path, endfix):
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
return False
|
|
154
|
-
path_depth = root.count(os.sep)
|
|
155
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
156
|
-
yield root, _, files
|
|
157
|
-
else:
|
|
158
|
-
_[:] = []
|
|
141
|
+
files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
|
|
142
|
+
for file in files:
|
|
143
|
+
if not file['file'].endswith(endfix):
|
|
144
|
+
return False
|
|
159
145
|
return True
|
|
160
146
|
|
|
161
147
|
@staticmethod
|
|
@@ -273,20 +259,15 @@ class Util:
|
|
|
273
259
|
self.check_path_valid(path)
|
|
274
260
|
file_list = {}
|
|
275
261
|
re_pattern = re.compile(pattern)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
path_depth = dir_path.count(os.sep)
|
|
286
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
287
|
-
yield dir_path, _, file_names
|
|
288
|
-
else:
|
|
289
|
-
_[:] = []
|
|
262
|
+
files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
|
|
263
|
+
for file in files:
|
|
264
|
+
name = file["file"]
|
|
265
|
+
match = re_pattern.match(name)
|
|
266
|
+
if not match:
|
|
267
|
+
continue
|
|
268
|
+
if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
|
|
269
|
+
continue
|
|
270
|
+
file_list[name] = gen_info_func(name, match, file["root"])
|
|
290
271
|
return file_list
|
|
291
272
|
|
|
292
273
|
def check_file_path_format(self, path, suffix):
|
|
@@ -65,6 +65,8 @@ class Visualization:
|
|
|
65
65
|
self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
|
|
66
66
|
self.util.log.warning("Please check the pkl file")
|
|
67
67
|
raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
|
|
68
|
+
if not isinstance(msg, list) or len(msg) == 0:
|
|
69
|
+
break
|
|
68
70
|
info_prefix = msg[0]
|
|
69
71
|
if not info_prefix.startswith(api_name):
|
|
70
72
|
continue
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -14,11 +14,13 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
import re
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import Const
|
|
19
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
-
from msprobe.core.common.file_utils import FileOpen, load_json
|
|
21
|
+
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.utils import is_int
|
|
22
24
|
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
23
25
|
from msprobe.core.grad_probe.constant import level_adp
|
|
24
26
|
from msprobe.core.grad_probe.utils import check_bounds
|
|
@@ -38,17 +40,38 @@ class TensorConfig(BaseConfig):
|
|
|
38
40
|
self.host = json_config.get("host", "")
|
|
39
41
|
self.port = json_config.get("port", -1)
|
|
40
42
|
self.tls_path = json_config.get("tls_path", "./")
|
|
43
|
+
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
41
44
|
self.check_config()
|
|
42
45
|
self._check_file_format()
|
|
43
|
-
self.
|
|
46
|
+
if self.online_run_ut:
|
|
47
|
+
self._check_online_run_ut()
|
|
44
48
|
|
|
45
49
|
def _check_file_format(self):
|
|
46
50
|
if self.file_format is not None and self.file_format not in ["npy", "bin"]:
|
|
47
51
|
raise Exception("file_format is invalid")
|
|
48
52
|
|
|
49
|
-
def
|
|
50
|
-
if
|
|
51
|
-
raise Exception("
|
|
53
|
+
def _check_online_run_ut(self):
|
|
54
|
+
if not isinstance(self.online_run_ut, bool):
|
|
55
|
+
raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
|
|
56
|
+
|
|
57
|
+
if not isinstance(self.online_run_ut_recompute, bool):
|
|
58
|
+
raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
|
|
59
|
+
|
|
60
|
+
if self.nfs_path:
|
|
61
|
+
check_file_or_directory_path(self.nfs_path, isdir=True)
|
|
62
|
+
return
|
|
63
|
+
|
|
64
|
+
if self.tls_path:
|
|
65
|
+
check_file_or_directory_path(self.tls_path, isdir=True)
|
|
66
|
+
check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
|
|
67
|
+
check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
|
|
68
|
+
check_crt_valid(os.path.join(self.tls_path, "client.crt"))
|
|
69
|
+
|
|
70
|
+
if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
|
|
71
|
+
raise Exception(f"host: {self.host} is invalid.")
|
|
72
|
+
|
|
73
|
+
if not isinstance(self.port, int) or not (0 < self.port <= 65535):
|
|
74
|
+
raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
|
|
52
75
|
|
|
53
76
|
|
|
54
77
|
class StatisticsConfig(BaseConfig):
|
|
@@ -70,7 +93,7 @@ class OverflowCheckConfig(BaseConfig):
|
|
|
70
93
|
self.check_overflow_config()
|
|
71
94
|
|
|
72
95
|
def check_overflow_config(self):
|
|
73
|
-
if self.overflow_nums is not None and not
|
|
96
|
+
if self.overflow_nums is not None and not is_int(self.overflow_nums):
|
|
74
97
|
raise Exception("overflow_num is invalid")
|
|
75
98
|
if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
|
|
76
99
|
raise Exception("check_mode is invalid")
|
|
@@ -170,7 +193,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
|
|
|
170
193
|
)
|
|
171
194
|
|
|
172
195
|
def _check_preheat_config(self):
|
|
173
|
-
if not
|
|
196
|
+
if not is_int(self.preheat_step):
|
|
174
197
|
msg = "preheat_step is invalid, it should be an integer"
|
|
175
198
|
logger.error_log_with_exp(
|
|
176
199
|
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
@@ -180,7 +203,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
|
|
|
180
203
|
logger.error_log_with_exp(
|
|
181
204
|
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
182
205
|
)
|
|
183
|
-
if not
|
|
206
|
+
if not is_int(self.max_sample):
|
|
184
207
|
msg = "max_sample is invalid, it should be an integer"
|
|
185
208
|
logger.error_log_with_exp(
|
|
186
209
|
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
msprobe/pytorch/service.py
CHANGED
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
import os
|
|
18
|
-
|
|
19
18
|
from collections import namedtuple
|
|
19
|
+
|
|
20
20
|
import torch
|
|
21
21
|
from msprobe.core.common.const import Const
|
|
22
22
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
@@ -25,13 +25,14 @@ from msprobe.core.common.utils import print_tools_ends_info
|
|
|
25
25
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
26
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
27
27
|
from msprobe.core.data_dump.scope import BaseScope
|
|
28
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
28
29
|
from msprobe.pytorch.common.log import logger
|
|
29
30
|
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
31
|
+
from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
30
32
|
from msprobe.pytorch.hook_module import remove_dropout
|
|
31
33
|
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
32
34
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
33
35
|
from msprobe.pytorch.module_processer import ModuleProcesser
|
|
34
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
35
36
|
|
|
36
37
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
37
38
|
if torch_version_above_or_equal_2:
|
|
@@ -159,10 +160,10 @@ class Service:
|
|
|
159
160
|
if api_origin:
|
|
160
161
|
api_register.api_modularity()
|
|
161
162
|
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
162
|
-
run_ut_dispatch(self.attl, True)
|
|
163
|
+
run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
|
|
163
164
|
self.switch = True
|
|
164
165
|
logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
165
|
-
if
|
|
166
|
+
if not self.config.online_run_ut:
|
|
166
167
|
self.create_dirs()
|
|
167
168
|
logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
168
169
|
|
|
@@ -177,7 +178,7 @@ class Service:
|
|
|
177
178
|
return
|
|
178
179
|
self.switch = False
|
|
179
180
|
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
180
|
-
run_ut_dispatch(self.attl, False)
|
|
181
|
+
run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
|
|
181
182
|
return
|
|
182
183
|
self.data_collector.write_json()
|
|
183
184
|
|
|
@@ -191,6 +192,9 @@ class Service:
|
|
|
191
192
|
HOOKModule.reset_module_stats()
|
|
192
193
|
self.data_collector.data_writer.reset_cache()
|
|
193
194
|
|
|
195
|
+
if self.config.level == Const.LEVEL_L2:
|
|
196
|
+
self.data_collector.data_processor.reset_status()
|
|
197
|
+
|
|
194
198
|
def need_stop_service(self):
|
|
195
199
|
if self.should_stop_service:
|
|
196
200
|
return True
|
|
@@ -221,6 +225,12 @@ class Service:
|
|
|
221
225
|
create_directory(self.config.dump_path)
|
|
222
226
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
223
227
|
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
228
|
+
if self.config.level == Const.LEVEL_L2:
|
|
229
|
+
create_directory(self.dump_iter_dir)
|
|
230
|
+
kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
|
|
231
|
+
self.config.kernel_config_path = kernel_config_path
|
|
232
|
+
return
|
|
233
|
+
|
|
224
234
|
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
225
235
|
create_directory(dump_dir)
|
|
226
236
|
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from msprobe.visualization.graph.graph import Graph
|
|
18
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
19
|
+
from msprobe.visualization.utils import save_json_file, GraphConst
|
|
20
|
+
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
+
from msprobe.core.common.file_utils import load_json
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GraphBuilder:
    """Build model graphs from msprobe dump artifacts and export them as .vis JSON."""

    @staticmethod
    def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
        """Public entry point: build a Graph from dump artifacts.

        Args:
            construct_path: path to construct.json, a mapping of each node name
                to its parent node name (or a falsy value for top-level nodes).
            data_path: path to dump.json. Its ``GraphConst.DATA_KEY`` entry supplies
                per-node data, and its ``'dump_data_dir'`` entry is forwarded to Graph.
            stack_path: path to stack.json, a mapping of node name to stack info.
            model_name: name of the model, supplied by the caller.

        Returns:
            Graph: the constructed graph.
        """
        construct_info = load_json(construct_path)
        dump_info = load_json(data_path)
        stack_info = load_json(stack_path)
        per_node_data = dump_info.get(GraphConst.DATA_KEY, {})
        graph = Graph(model_name, data_path=dump_info.get('dump_data_dir', ''), dump_data=per_node_data)
        GraphBuilder._init_nodes(graph, construct_info, per_node_data, stack_info)
        GraphBuilder._collect_apis_between_modules(graph)
        return graph

    @staticmethod
    def to_json(filename, config):
        """Export the graph(s) held by ``config`` to a .vis JSON file.

        When ``config.graph_b`` is set, the NPU graph and the bench graph are
        stored under their own keys. Otherwise the NPU graph's dict is the
        top-level object. Tool tips, node colors, micro steps and task are added
        only when truthy.
        """
        if config.graph_b:
            payload = {
                GraphConst.JSON_NPU_KEY: config.graph_n.to_dict(),
                GraphConst.JSON_BENCH_KEY: config.graph_b.to_dict(),
            }
        else:
            payload = config.graph_n.to_dict()

        optional_entries = (
            (GraphConst.JSON_TIP_KEY, config.tool_tip),
            (GraphConst.COLORS, config.node_colors),
            (GraphConst.MICRO_STEPS, config.micro_steps),
            (GraphConst.JSON_TASK_KEY, config.task),
        )
        for entry_key, entry_value in optional_entries:
            if entry_value:
                payload[entry_key] = entry_value
        save_json_file(filename, payload)

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """Recover a missing parent for a backward node from its forward twin.

        If ``subnode_id`` ends in ``.backward.<n>`` and has no parent, look up the
        parent of the matching ``.forward.<n>`` node. Then use that parent's
        ``.backward.<n>`` counterpart if it exists in ``construct_dict``. In every
        other case ``upnode_id`` is returned unchanged.
        """
        # Node names ending in ".backward." / ".forward." followed by digits.
        backward_pattern = r"(\.backward\.)(\d+)$"
        forward_pattern = r"(\.forward\.)(\d+)$"
        if not re.search(backward_pattern, subnode_id) or upnode_id:
            return upnode_id
        forward_twin_id = re.sub(backward_pattern, r".forward.\2", subnode_id)
        forward_parent_id = construct_dict.get(forward_twin_id)
        if not forward_parent_id:
            return upnode_id
        candidate_id = re.sub(forward_pattern, r".backward.\2", forward_parent_id)
        return candidate_id if candidate_id in construct_dict else upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict, stack_dict):
        """Create every node listed in ``construct_dict`` and link it to its parent.

        Nodes without a (recoverable) parent are attached to ``graph.root``.
        """
        for child_id, raw_parent_id in construct_dict.items():
            parent_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, child_id, raw_parent_id)
            if parent_id:
                parent_op = NodeOp.get_node_op(parent_id)
                parent_node = GraphBuilder._create_or_get_node(
                    graph, [data_dict, stack_dict], parent_op, parent_id)
            else:
                parent_node = graph.root
            child_op = NodeOp.get_node_op(child_id)
            GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], child_op, child_id, parent_node)

    @staticmethod
    def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
        """Fetch node ``name`` from ``graph``, creating it first if it is absent.

        Args:
            graph: graph being built.
            data_stack_list: two-element sequence ``[data_dict, stack_dict]`` holding
                per-node dump data and per-node stack info.
            op: NodeOp used if the node has to be created.
            name: node name / id.
            upnode: parent node, or None.

        Returns:
            The node, with its input/output data, stack info and parent refreshed.
        """
        if name not in graph.node_map:
            graph.add_node(op, name, upnode)
        node = graph.get_node(name)
        dump_source, stack_source = data_stack_list[0], data_stack_list[1]
        input_data, output_data = get_input_output(dump_source.get(name, {}), node.id)
        node.set_input_output(input_data, output_data)
        node.stack_info = stack_source.get(name, [])
        node.add_upnode(upnode)
        return node

    @staticmethod
    def _collect_apis_between_modules(graph):
        """Group runs of consecutive top-level API nodes into collection nodes.

        When the graph is first expanded, the root's children mix modules and
        many APIs, which makes the view very long. Every run of two or more
        consecutive ``NodeOp.function_api`` children of ``graph.root`` is moved
        under a new ``NodeOp.api_collection`` node. A single API node, or a module
        node, is kept as is.

        Args:
            graph: model graph whose root children are regrouped in place.

        Returns:
            None
        """
        nodes = graph.root.subnodes
        regrouped = []
        idx = 0
        while idx < len(nodes):
            if nodes[idx].op != NodeOp.function_api:
                regrouped.append(nodes[idx])
                idx += 1
                continue

            run_start = idx
            while idx < len(nodes) and nodes[idx].op == NodeOp.function_api:
                idx += 1
            api_run = nodes[run_start:idx]

            if len(api_run) < 2:
                regrouped.extend(api_run)
                continue

            collection_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                           id_accumulation=True)
            collection_node = graph.get_node(collection_id)
            collection_node.subnodes = api_run
            # Re-parent the grouped APIs under the new collection node.
            for api_node in api_run:
                api_node.upnode = collection_node
            collection_node.upnode = graph.root
            regrouped.append(collection_node)

        graph.root.subnodes = regrouped
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class GraphExportConfig:
    """
    Plain container for the inputs used when exporting a graph result.

    Every constructor argument is stored unchanged on an attribute of the same name:
        graph_n: the primary graph (presumably the NPU side - confirm with callers).
        graph_b: optional second graph (presumably the benchmark side); None by default.
        tool_tip: optional tool-tip payload; None by default.
        node_colors: optional node colour configuration; None by default.
        micro_steps: optional micro-step information; None by default.
        task: task identifier; empty string by default.
    """

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''):
        # The two graphs
        self.graph_n, self.graph_b = graph_n, graph_b
        # Optional export extras
        self.tool_tip, self.node_colors = tool_tip, node_colors
        self.micro_steps, self.task = micro_steps, task
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import re
|
|
16
|
+
import math
|
|
17
|
+
from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
|
|
18
|
+
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
|
+
from msprobe.visualization.utils import GraphConst
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
# Rules used to parse a node name into its NodeOp. Each entry corresponds to the NodeOp named in its
# comment (presumably matched by list position by the caller - keep the order).
# The '.' separators are escaped so each prefix only matches a literal dot; an unescaped '.' matches
# any character, so e.g. 'Torch.' would also have accepted 'Torchvision...'.
op_patterns = [
    # NodeOp.module
    r'^(Module\.|Cell\.)',
    # NodeOp.function_api
    r'^(Tensor\.|Torch\.|Functional\.|NPU\.|VF\.|Distributed\.|Aten\.|Mint\.|Primitive\.|Jit\.|MintFunctional\.)'
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_compare_mode(dump_path_param):
    """
    Determine the graph-compare mode for the dump described by ``dump_path_param``.

    Args:
        dump_path_param: parameters the acc_compare interface relies on.

    Returns: the value ``GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING`` assigns to the detected
        dump mode (per the original docs: 0 summary, 1 md5, 2 real data); ``.get`` yields None for an
        unmapped mode.
    """
    # The dump path is registered before the dump mode is queried
    set_dump_path(dump_path_param)
    detected_dump_mode = get_dump_mode(dump_path_param)
    mode_mapping = GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING
    return mode_mapping.get(detected_dump_mode)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
    """
    Run the multi-process real-data comparison through the framework-specific comparator.

    Args:
        dump_path_param: parameters the acc_compare interface relies on.
        csv_path: path of the file to generate.
        framework: framework type; ``Const.PT_FRAMEWORK`` selects the PyTorch comparator and any
            other value selects the MindSpore comparator.
        is_cross_frame: whether to run a cross-framework comparison (only MindSpore against PyTorch,
            with PyTorch as the benchmark); applied to the MindSpore comparator only.

    Returns: whatever the selected comparator's ``do_multi_process`` returns.
    """
    if framework != Const.PT_FRAMEWORK:
        # Imported inside the branch so only the selected framework's comparator module is loaded
        from msprobe.mindspore.compare.ms_compare import MSComparator
        comparator = MSComparator()
        comparator.cross_frame = is_cross_frame
    else:
        from msprobe.pytorch.compare.pt_compare import PTComparator
        comparator = PTComparator()
    return comparator.do_multi_process(dump_path_param, csv_path)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_input_output(node_data, node_id):
    """
    Split one node's raw dump data into separate input and output dicts.

    Entries produced by ``read_op`` without a ``full_op_name`` are skipped. An entry is treated as
    output when its ``full_op_name`` contains ``GraphConst.OUTPUT`` but not ``GraphConst.INPUT``;
    every other entry is treated as input.

    Outputs are keyed by ``full_op_name``. Inputs are keyed by their ``data_name`` (the name of the
    data written to disk) with the last ``Const.SEP``-separated suffix removed, when ``data_name`` is
    a str other than '-1'; otherwise they are keyed by ``full_op_name``.

    Args:
        node_data: dump data belonging to a single node.
        node_id: the node's name.

    Returns: tuple ``(input_data, output_data)`` of dicts.
    """
    input_data, output_data = {}, {}
    for parsed in read_op(node_data, node_id):
        op_name = parsed.get('full_op_name', '')
        if not op_name:
            continue
        if GraphConst.OUTPUT in op_name and GraphConst.INPUT not in op_name:
            output_data[op_name] = parsed
            continue
        data_name = parsed.get('data_name')
        # Prefer the on-disk data name for the parameter key when one is available
        has_data_name = isinstance(data_name, str) and data_name != '-1'
        input_key = data_name.rsplit(Const.SEP, 1)[0] if has_data_name else op_name
        input_data[input_key] = parsed
    return input_data, output_data
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def compare_data(data_dict_list1, data_dict_list2):
    """
    Check whether two ``get_input_output`` results have the same structure.

    Both mappings must have the same number of entries. Entries are then paired in iteration
    order, and each pair must agree on its 'type' and 'shape' values (a missing key counts as None).

    Returns: True when the structures match, otherwise False.
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    return not any(
        data_dict_list1[first_key].get(field, None) != data_dict_list2[second_key].get(field, None)
        for first_key, second_key in zip(data_dict_list1, data_dict_list2)
        # Fields that decide whether two parameters are structurally equal
        for field in ('type', 'shape')
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def format_node_data(data_dict):
    """
    Prepare every parameter entry of ``data_dict`` for output, modifying it in place.

    Each dict-valued entry loses its 'requires_grad' and 'full_op_name' keys (when present) and is
    then normalized by ``_format_data``. Non-dict values are left untouched.

    Returns: the same ``data_dict`` object.
    """
    for entry in data_dict.values():
        if not isinstance(entry, dict):
            continue
        for unwanted_key in ('requires_grad', 'full_op_name'):
            entry.pop(unwanted_key, None)
        _format_data(entry)
    return data_dict
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
    """
    Compute accuracy metrics for a pair of nodes with acc_compare's ``get_accuracy``.

    Real-data compare mode cannot produce accuracy metrics here; that mode has to go through the
    multi-process compare interface instead.

    Args:
        node_ids: pair of node names; entry 0 is looked up in ``data_dicts[0]``, entry 1 in ``data_dicts[1]``.
        data_dicts: pair of dump-data dicts matching ``node_ids``.
        stack_json_data: stack info keyed by node name, used for both nodes.
        compare_mode: graph-compare mode, translated to a dump mode via
            ``GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING``.

    Returns: the list filled by ``get_accuracy`` with parameter info and (except in real-data mode)
        compare metrics.
    """
    merged_n, merged_b = (
        _parse_node(node_ids[side], data_dicts[side], stack_json_data, compare_mode)
        for side in (0, 1)
    )
    accuracy_rows = []
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    # get_accuracy appends its results into accuracy_rows
    get_accuracy(accuracy_rows, merged_n, merged_b, dump_mode)
    return accuracy_rows
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
    """
    Convert one node into the merged form that acc_compare's ``get_accuracy`` takes as input.

    The node's dump entries are read with ``read_op``. A trailing entry carrying the node's stack
    info is appended; its stack info is None when ``stack_json_data`` has no entry for the node.
    Everything is then merged with ``merge_tensor``. An empty merge result gets 'op_name' set to an
    empty list.
    """
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    parsed_ops = read_op(data_dict.get(node_id, {}), node_id)
    stack_info = stack_json_data[node_id] if node_id in stack_json_data else None
    parsed_ops.append({'full_op_name': node_id, 'full_info': stack_info})
    merged = merge_tensor(parsed_ops, dump_mode)
    if not merged:
        # NOTE(review): presumably downstream code (get_accuracy) expects the op_name key - confirm
        merged['op_name'] = []
    return merged
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _format_decimal_string(s):
    """
    Round the decimal numbers embedded in ``s`` to ``GraphConst.ROUND_TH`` fractional digits.

    Matches runs of digits containing a decimal point, optionally followed by '%'. A number whose
    fractional part is longer than ``GraphConst.ROUND_TH`` digits is re-rendered with exactly that
    many digits, and a trailing '%' is kept. Shorter numbers are left unchanged.

    Each match is rewritten at its own position with ``re.sub``. The previous ``str.replace``
    approach also rewrote the same digits inside a longer number elsewhere in the string. For
    example, "1.1234567, 1.12345678" used to become "1.123457, 1.1234578".
    """
    # Bounded quantifiers keep the regex cheap on long inputs
    pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')

    def _round_match(match):
        text = match.group(0)
        is_percent = text.endswith('%')
        number_str = text.rstrip('%')
        decimal_part = number_str.split('.')[1]
        # Only numbers with more than ROUND_TH fractional digits are shortened
        if len(decimal_part) <= GraphConst.ROUND_TH:
            return text
        formatted_number = f"{float(number_str):.{GraphConst.ROUND_TH}f}"
        # Restore the percent sign when the original was a percentage
        return formatted_number + '%' if is_percent else formatted_number

    return pattern.sub(_round_match, s)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _format_data(data_dict):
    """
    Normalize the values of one parameter's data dict in place so the front end can display them.

    - str values: single quotes are removed, ``GraphConst.NONE`` is replaced by ``GraphConst.NULL``,
      and long decimals are rounded by ``_format_decimal_string``.
    - None: replaced by ``GraphConst.NULL``.
    - float values whose ``str`` form is scientific notation (and shorter than
      ``GraphConst.STR_MAX_LEN``): re-rendered with ``GraphConst.ROUND_TH`` mantissa digits,
      e.g. 1.123123123123e-11 -> 1.123123e-11.
    - other float values: rounded to ``GraphConst.ROUND_TH`` decimal places.

    Afterwards every value except the one under ``GraphConst.ERROR_KEY`` is converted to str. If
    the ``Const.MAX`` entry ends up as ``GraphConst.NULL``, the parameter is considered null and the
    dict is reduced to ``{GraphConst.VALUE: GraphConst.NULL}``.
    """
    # Only matches scientific notation: the exponent group is mandatory
    sci_notation = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$')
    all_null = False
    for key, value in data_dict.items():
        if isinstance(value, str):
            # Drop single quotes and turn None into null to avoid front-end parsing errors.
            # NOTE(review): a blank ' ' string is also handled here and stays ' ' (it is not turned
            # into null); confirm that is intended.
            value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            value = _format_decimal_string(value)
        elif value is None:
            value = GraphConst.NULL
        elif isinstance(value, float):
            text = str(value)
            # The length check keeps the regex away from overly long inputs
            if len(text) < GraphConst.STR_MAX_LEN and sci_notation.match(text):
                value = f"{value:.{GraphConst.ROUND_TH}e}"
            else:
                # inf/nan land here too; round() leaves them unchanged and str() below renders them
                value = round(value, GraphConst.ROUND_TH)
        # Fallback for every other type: everything except error_key becomes a str so the front end
        # can parse it
        if key != GraphConst.ERROR_KEY:
            value = str(value)
        # A null Max means the whole parameter value is null
        if key == Const.MAX and value == GraphConst.NULL:
            all_null = True
        data_dict[key] = value
    # Collapse an all-null parameter to a single null entry
    if all_null:
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|