mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -20,6 +20,7 @@ import mindspore as ms
|
|
|
20
20
|
|
|
21
21
|
from mindspore import ops
|
|
22
22
|
from mindspore.mint import nn
|
|
23
|
+
|
|
23
24
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
24
25
|
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
25
26
|
from msprobe.core.common.log import logger
|
|
@@ -43,7 +44,7 @@ def convert_bf16_to_fp32(tensor):
|
|
|
43
44
|
def save_tensor_as_npy(tensor, file_path):
|
|
44
45
|
if not path_len_exceeds_limit(file_path):
|
|
45
46
|
tensor = convert_bf16_to_fp32(tensor)
|
|
46
|
-
saved_tensor = tensor.
|
|
47
|
+
saved_tensor = tensor.asnumpy()
|
|
47
48
|
save_npy(saved_tensor, file_path)
|
|
48
49
|
else:
|
|
49
50
|
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
@@ -56,6 +57,11 @@ def convert_to_int(value):
|
|
|
56
57
|
return -1
|
|
57
58
|
|
|
58
59
|
|
|
60
|
+
def clean_input_kwargs(cell):
|
|
61
|
+
if hasattr(cell, 'input_kwargs'):
|
|
62
|
+
del cell.input_kwargs
|
|
63
|
+
|
|
64
|
+
|
|
59
65
|
def list_lowest_level_directories(root_dir):
|
|
60
66
|
check_path_exists(root_dir)
|
|
61
67
|
lowest_level_dirs = []
|
|
@@ -77,7 +83,7 @@ def list_lowest_level_directories(root_dir):
|
|
|
77
83
|
|
|
78
84
|
|
|
79
85
|
def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
80
|
-
check_seed_all(seed, mode)
|
|
86
|
+
check_seed_all(seed, mode, rm_dropout)
|
|
81
87
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
82
88
|
ms.set_seed(seed)
|
|
83
89
|
random.seed(seed)
|
|
@@ -102,8 +108,8 @@ class MsprobeStep(ms.train.Callback):
|
|
|
102
108
|
|
|
103
109
|
|
|
104
110
|
class Dropout(ops.Dropout):
|
|
105
|
-
def __init__(self, keep_prob=0.5,
|
|
106
|
-
super().__init__(1.,
|
|
111
|
+
def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
|
|
112
|
+
super().__init__(1., seed0, seed1)
|
|
107
113
|
|
|
108
114
|
|
|
109
115
|
class Dropout2D(ops.Dropout2D):
|
|
@@ -134,3 +140,42 @@ def remove_dropout():
|
|
|
134
140
|
ops.operations.Dropout3D = Dropout3D
|
|
135
141
|
nn.Dropout = DropoutExt
|
|
136
142
|
nn.functional.dropout = dropout_ext
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
mindtorch_check_result = None
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def is_mindtorch():
|
|
149
|
+
global mindtorch_check_result
|
|
150
|
+
if mindtorch_check_result is None:
|
|
151
|
+
mindtorch_check_result = False
|
|
152
|
+
try:
|
|
153
|
+
import torch
|
|
154
|
+
from mindspore._c_expression import Tensor
|
|
155
|
+
except ImportError:
|
|
156
|
+
return mindtorch_check_result
|
|
157
|
+
tensor = torch.tensor(0.0)
|
|
158
|
+
if isinstance(tensor, Tensor):
|
|
159
|
+
mindtorch_check_result = True
|
|
160
|
+
return mindtorch_check_result
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
register_backward_hook_functions = {}
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def set_register_backward_hook_functions():
|
|
167
|
+
global register_backward_hook_functions
|
|
168
|
+
if is_mindtorch():
|
|
169
|
+
import torch
|
|
170
|
+
from msprobe.mindspore.mindtorch import (_call_impl,
|
|
171
|
+
register_full_backward_pre_hook,
|
|
172
|
+
register_full_backward_hook)
|
|
173
|
+
if not hasattr(torch, "register_full_backward_hook"):
|
|
174
|
+
setattr(torch.nn.Module, "_call_impl", _call_impl)
|
|
175
|
+
setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
|
|
176
|
+
setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
|
|
177
|
+
register_backward_hook_functions["pre"] = torch.nn.Module.register_full_backward_pre_hook
|
|
178
|
+
register_backward_hook_functions["full"] = torch.nn.Module.register_full_backward_hook
|
|
179
|
+
else:
|
|
180
|
+
register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook
|
|
181
|
+
register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook
|
|
@@ -41,12 +41,10 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
41
41
|
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
42
42
|
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
43
43
|
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
44
|
-
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
45
44
|
|
|
46
45
|
dump_result_param = {
|
|
47
46
|
'npu_json_path': npu_path,
|
|
48
47
|
'bench_json_path': bench_path,
|
|
49
|
-
'stack_json_path': stack_path,
|
|
50
48
|
'is_print_compare_log': is_print_compare_log
|
|
51
49
|
}
|
|
52
50
|
ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
|
-
|
|
19
18
|
from collections import defaultdict
|
|
20
19
|
|
|
21
20
|
import numpy as np
|
|
@@ -23,15 +22,21 @@ import pandas as pd
|
|
|
23
22
|
|
|
24
23
|
from msprobe.core.common.const import CompareConst, Const
|
|
25
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
26
|
-
from msprobe.core.common.file_utils import
|
|
27
|
-
load_npy, load_yaml)
|
|
25
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
|
|
28
26
|
from msprobe.core.common.log import logger
|
|
29
|
-
from msprobe.core.common.utils import
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
|
|
28
|
+
check_op_str_pattern_valid, get_dump_mode, set_dump_path
|
|
29
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
32
30
|
from msprobe.core.compare.check import dtype_mapping
|
|
33
|
-
from msprobe.core.compare.acc_compare import Comparator
|
|
34
31
|
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
32
|
+
from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MappingConfig:
|
|
36
|
+
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
|
|
37
|
+
self.cell_mapping = cell_mapping
|
|
38
|
+
self.api_mapping = api_mapping
|
|
39
|
+
self.data_mapping = data_mapping
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
class MSComparator(Comparator):
|
|
@@ -42,18 +47,27 @@ class MSComparator(Comparator):
|
|
|
42
47
|
data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
|
|
43
48
|
is_cross_framework: 是否跨框架。
|
|
44
49
|
"""
|
|
45
|
-
def __init__(self,
|
|
50
|
+
def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
|
|
51
|
+
super().__init__(mode_config)
|
|
46
52
|
self.frame_name = MSComparator.__name__
|
|
47
|
-
|
|
48
|
-
self.
|
|
49
|
-
self.
|
|
50
|
-
|
|
53
|
+
|
|
54
|
+
self.stack_mode = mode_config.stack_mode
|
|
55
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
56
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
57
|
+
self.dump_mode = mode_config.dump_mode
|
|
58
|
+
|
|
59
|
+
if mapping_config:
|
|
60
|
+
self.cell_mapping = mapping_config.cell_mapping
|
|
61
|
+
self.api_mapping = mapping_config.api_mapping
|
|
62
|
+
self.data_mapping = mapping_config.data_mapping
|
|
63
|
+
|
|
64
|
+
if self.data_mapping:
|
|
51
65
|
self.cross_frame = is_cross_framework
|
|
52
66
|
else:
|
|
53
|
-
self.cross_frame = cell_mapping is not None or api_mapping is not None
|
|
67
|
+
self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
|
|
54
68
|
self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
|
|
55
69
|
self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
|
|
56
|
-
if api_mapping is not None:
|
|
70
|
+
if self.api_mapping is not None:
|
|
57
71
|
self.ms_to_pt_mapping = self.load_internal_api()
|
|
58
72
|
|
|
59
73
|
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
@@ -63,9 +77,8 @@ class MSComparator(Comparator):
|
|
|
63
77
|
else:
|
|
64
78
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
65
79
|
f"{type(self.data_mapping)}")
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def calc_accuracy(cls, result_df, dump_mode, header):
|
|
80
|
+
|
|
81
|
+
def calc_accuracy(self, result_df, header):
|
|
69
82
|
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
70
83
|
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
71
84
|
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
@@ -76,10 +89,10 @@ class MSComparator(Comparator):
|
|
|
76
89
|
val_str = val.astype(str)
|
|
77
90
|
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
78
91
|
return check_series
|
|
79
|
-
|
|
92
|
+
|
|
80
93
|
def get_number(val):
|
|
81
94
|
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
82
|
-
|
|
95
|
+
|
|
83
96
|
ms_val = result_df['NPU ' + data_type]
|
|
84
97
|
pt_val = result_df['Bench ' + data_type]
|
|
85
98
|
diff_name = data_type.capitalize() + ' diff'
|
|
@@ -93,7 +106,7 @@ class MSComparator(Comparator):
|
|
|
93
106
|
condition_pt_zero = pt_val == 0
|
|
94
107
|
result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
|
|
95
108
|
condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
|
|
96
|
-
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
109
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
97
110
|
pt_val[condition_ref_err] * 100)
|
|
98
111
|
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
|
|
99
112
|
.abs().astype(str) + '%')
|
|
@@ -101,31 +114,30 @@ class MSComparator(Comparator):
|
|
|
101
114
|
pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
|
|
102
115
|
return magnitude > CompareConst.MAGNITUDE
|
|
103
116
|
|
|
104
|
-
if dump_mode == Const.MD5:
|
|
117
|
+
if self.dump_mode == Const.MD5:
|
|
105
118
|
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
106
119
|
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
107
120
|
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
108
|
-
elif dump_mode == Const.SUMMARY:
|
|
121
|
+
elif self.dump_mode == Const.SUMMARY:
|
|
109
122
|
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
110
123
|
warning_flag = pd.DataFrame(warning_list).all()
|
|
111
124
|
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
112
125
|
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
113
126
|
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
114
127
|
else:
|
|
115
|
-
fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
128
|
+
fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
116
129
|
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
117
130
|
CompareConst.ERROR_MESSAGE]
|
|
118
131
|
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
119
132
|
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
120
133
|
return result_df[header]
|
|
121
134
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
|
|
135
|
+
def make_result_df(self, result):
|
|
136
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
|
|
125
137
|
|
|
126
|
-
if stack_mode:
|
|
138
|
+
if self.stack_mode:
|
|
127
139
|
header.append(CompareConst.STACK)
|
|
128
|
-
if dump_mode == Const.ALL:
|
|
140
|
+
if self.dump_mode == Const.ALL:
|
|
129
141
|
header.append(CompareConst.DATA_NAME)
|
|
130
142
|
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
131
143
|
'op_name_y': CompareConst.BENCH_NAME,
|
|
@@ -137,10 +149,11 @@ class MSComparator(Comparator):
|
|
|
137
149
|
'md5_y': CompareConst.BENCH_MD5,
|
|
138
150
|
'data_name_x': CompareConst.DATA_NAME,
|
|
139
151
|
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
140
|
-
|
|
152
|
+
|
|
141
153
|
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
142
|
-
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
154
|
+
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
143
155
|
CompareConst.BENCH_NORM]
|
|
156
|
+
|
|
144
157
|
def set_summary(summary):
|
|
145
158
|
if summary == CompareConst.N_A:
|
|
146
159
|
return [CompareConst.N_A] * 4
|
|
@@ -153,14 +166,14 @@ class MSComparator(Comparator):
|
|
|
153
166
|
else:
|
|
154
167
|
summary_list.append(i)
|
|
155
168
|
return summary_list
|
|
156
|
-
|
|
169
|
+
|
|
157
170
|
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
158
171
|
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
159
172
|
result_df = pd.DataFrame(columns=header)
|
|
160
173
|
for h in header:
|
|
161
174
|
if h in result.columns:
|
|
162
175
|
result_df[h] = result[h]
|
|
163
|
-
return
|
|
176
|
+
return self.calc_accuracy(result_df, header)
|
|
164
177
|
|
|
165
178
|
def load_internal_api(self):
|
|
166
179
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
@@ -175,13 +188,16 @@ class MSComparator(Comparator):
|
|
|
175
188
|
return mapping_dict
|
|
176
189
|
|
|
177
190
|
def process_cell_mapping(self, npu_op_name):
|
|
178
|
-
if not npu_op_name
|
|
191
|
+
if not npu_op_name:
|
|
192
|
+
return CompareConst.N_A
|
|
193
|
+
param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
|
|
194
|
+
if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
|
|
179
195
|
return CompareConst.N_A
|
|
180
196
|
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
181
197
|
if self.cell_mapping_dict:
|
|
182
198
|
# get cell name & class name from op_name
|
|
183
199
|
# Cell.fc1.Dense.forward.0.input.0
|
|
184
|
-
cell_name = re.split(r'\.(?:
|
|
200
|
+
cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
185
201
|
if cell_name in self.cell_mapping_dict:
|
|
186
202
|
npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
187
203
|
return npu_op_name
|
|
@@ -198,7 +214,7 @@ class MSComparator(Comparator):
|
|
|
198
214
|
data_value = data_value.to(torch.float32)
|
|
199
215
|
data_value = data_value.numpy()
|
|
200
216
|
else:
|
|
201
|
-
data_value = load_npy(data_path)
|
|
217
|
+
data_value = load_npy(data_path)
|
|
202
218
|
return data_value
|
|
203
219
|
|
|
204
220
|
def process_internal_api_mapping(self, npu_op_name):
|
|
@@ -214,7 +230,7 @@ class MSComparator(Comparator):
|
|
|
214
230
|
return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
|
|
215
231
|
else:
|
|
216
232
|
return npu_op_name
|
|
217
|
-
|
|
233
|
+
|
|
218
234
|
def get_api_name(self, api_list):
|
|
219
235
|
try:
|
|
220
236
|
api_name = api_list[0] + Const.SEP + api_list[1]
|
|
@@ -223,14 +239,14 @@ class MSComparator(Comparator):
|
|
|
223
239
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
224
240
|
return api_name
|
|
225
241
|
|
|
226
|
-
def compare_process(self, file_lists
|
|
242
|
+
def compare_process(self, file_lists):
|
|
227
243
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
228
244
|
npu_json_data = load_json(npu_json_path)
|
|
229
245
|
bench_json_data = load_json(bench_json_path)
|
|
230
|
-
stack_json_data = load_json(stack_json_path)
|
|
246
|
+
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
231
247
|
|
|
232
|
-
npu_df = self.gen_data_df(npu_json_data, stack_json_data
|
|
233
|
-
bench_df = self.gen_data_df(bench_json_data, stack_json_data
|
|
248
|
+
npu_df = self.gen_data_df(npu_json_data, stack_json_data)
|
|
249
|
+
bench_df = self.gen_data_df(bench_json_data, stack_json_data)
|
|
234
250
|
if self.cell_mapping:
|
|
235
251
|
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
|
|
236
252
|
elif self.api_mapping:
|
|
@@ -242,8 +258,8 @@ class MSComparator(Comparator):
|
|
|
242
258
|
npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
243
259
|
bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
244
260
|
npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
|
|
245
|
-
bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
|
|
246
261
|
bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
|
|
262
|
+
bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
|
|
247
263
|
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
|
|
248
264
|
how='outer')
|
|
249
265
|
match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
|
|
@@ -262,9 +278,9 @@ class MSComparator(Comparator):
|
|
|
262
278
|
((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
|
|
263
279
|
((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
|
|
264
280
|
((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
|
|
265
|
-
|
|
281
|
+
|
|
266
282
|
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
267
|
-
return
|
|
283
|
+
return self.make_result_df(match_result)
|
|
268
284
|
|
|
269
285
|
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
|
|
270
286
|
def get_api_indices_dict(op_name_df):
|
|
@@ -288,11 +304,17 @@ class MSComparator(Comparator):
|
|
|
288
304
|
return flag
|
|
289
305
|
|
|
290
306
|
for mapping_dict in self.api_mapping_dict:
|
|
291
|
-
|
|
292
|
-
|
|
307
|
+
keys_to_compare = [
|
|
308
|
+
('ms_args', 'pt_args'),
|
|
309
|
+
('ms_output', 'pt_output'),
|
|
310
|
+
('ms_parameters', 'pt_parameters'),
|
|
311
|
+
('ms_parameters_grad', 'pt_parameters_grad'),
|
|
312
|
+
]
|
|
313
|
+
if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare):
|
|
293
314
|
logger.warning('The user-defined mapping table is incorrect,\
|
|
294
315
|
make sure that the number of parameters is equal')
|
|
295
316
|
continue
|
|
317
|
+
|
|
296
318
|
ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
|
|
297
319
|
if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
|
|
298
320
|
continue
|
|
@@ -304,13 +326,17 @@ class MSComparator(Comparator):
|
|
|
304
326
|
is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
|
|
305
327
|
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
306
328
|
is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
|
|
329
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
330
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
|
|
331
|
+
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
332
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
|
|
307
333
|
else:
|
|
308
334
|
logger.error(f'Excepted op_name: {op_name}')
|
|
309
335
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
310
336
|
if is_abandoned:
|
|
311
337
|
npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
|
|
312
338
|
|
|
313
|
-
def gen_data_df(self, data_json,
|
|
339
|
+
def gen_data_df(self, data_json, stack_json_data):
|
|
314
340
|
result = {
|
|
315
341
|
CompareConst.OP_NAME: [],
|
|
316
342
|
Const.DTYPE: [],
|
|
@@ -318,29 +344,40 @@ class MSComparator(Comparator):
|
|
|
318
344
|
Const.SUMMARY: [],
|
|
319
345
|
'stack_info': []
|
|
320
346
|
}
|
|
321
|
-
if dump_mode == Const.ALL:
|
|
347
|
+
if self.dump_mode == Const.ALL:
|
|
322
348
|
result['data_name'] = []
|
|
323
|
-
elif dump_mode == Const.MD5:
|
|
349
|
+
elif self.dump_mode == Const.MD5:
|
|
324
350
|
result[Const.MD5] = []
|
|
325
351
|
for data_name in data_json['data']:
|
|
326
352
|
check_op_str_pattern_valid(data_name)
|
|
327
|
-
merge_list = self.gen_merge_list(data_json, data_name,
|
|
353
|
+
merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
|
|
328
354
|
if not merge_list:
|
|
329
355
|
continue
|
|
330
|
-
|
|
356
|
+
|
|
357
|
+
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
358
|
+
summary_list = merge_list.get(Const.SUMMARY)
|
|
359
|
+
data_name_list = merge_list.get('data_name')
|
|
360
|
+
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
361
|
+
summary_list,
|
|
362
|
+
data_name_list)
|
|
363
|
+
for op_name in op_name_reorder:
|
|
331
364
|
result[CompareConst.OP_NAME].append(op_name)
|
|
332
365
|
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
333
366
|
struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
|
|
334
|
-
|
|
367
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
335
368
|
struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
|
|
369
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
370
|
+
struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
|
|
371
|
+
else:
|
|
372
|
+
struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
|
|
336
373
|
result[Const.DTYPE].append(struct[0])
|
|
337
374
|
result[Const.SHAPE].append(struct[1])
|
|
338
|
-
if dump_mode == Const.MD5:
|
|
375
|
+
if self.dump_mode == Const.MD5:
|
|
339
376
|
result[Const.MD5].append(struct[2])
|
|
340
|
-
result[Const.SUMMARY].append(
|
|
341
|
-
result['stack_info'].append(merge_list['stack_info'][0])
|
|
342
|
-
if dump_mode == Const.ALL:
|
|
343
|
-
result['data_name'].append(
|
|
377
|
+
result[Const.SUMMARY].append(summary_reorder.pop(0))
|
|
378
|
+
result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None)
|
|
379
|
+
if self.dump_mode == Const.ALL:
|
|
380
|
+
result['data_name'].append(data_name_reorder.pop(0))
|
|
344
381
|
return pd.DataFrame(result)
|
|
345
382
|
|
|
346
383
|
|
|
@@ -355,7 +392,6 @@ def check_cross_framework(bench_json_path):
|
|
|
355
392
|
|
|
356
393
|
def ms_compare(input_param, output_path, **kwargs):
|
|
357
394
|
try:
|
|
358
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
359
395
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
360
396
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
361
397
|
cell_mapping = kwargs.get('cell_mapping', None)
|
|
@@ -366,15 +402,21 @@ def ms_compare(input_param, output_path, **kwargs):
|
|
|
366
402
|
|
|
367
403
|
set_dump_path(input_param)
|
|
368
404
|
dump_mode = get_dump_mode(input_param)
|
|
405
|
+
if 'stack_json_path' in input_param:
|
|
406
|
+
stack_mode = kwargs.get('stack_mode', False)
|
|
407
|
+
else:
|
|
408
|
+
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
369
409
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
370
410
|
create_directory(output_path)
|
|
371
|
-
check_compare_param(input_param, output_path, dump_mode)
|
|
411
|
+
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
372
412
|
except (CompareException, FileCheckException) as error:
|
|
373
413
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
374
414
|
raise CompareException(error.code) from error
|
|
375
415
|
if layer_mapping:
|
|
376
416
|
data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
417
|
+
|
|
418
|
+
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
|
|
419
|
+
mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
|
|
420
|
+
is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
|
|
421
|
+
ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework)
|
|
422
|
+
ms_comparator.compare_core(input_param, output_path, suffix=suffix)
|
|
@@ -25,7 +25,7 @@ from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
|
|
|
25
25
|
from msprobe.core.common.log import logger
|
|
26
26
|
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
27
27
|
from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
|
|
28
|
-
from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check,
|
|
28
|
+
from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, compare_ops_apply
|
|
29
29
|
from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
|
|
30
30
|
|
|
31
31
|
|
|
@@ -144,10 +144,16 @@ def generate_data_name(data_path):
|
|
|
144
144
|
mode = GraphMode.STATISTIC_MODE
|
|
145
145
|
else:
|
|
146
146
|
mode = GraphMode.ERROR_MODE
|
|
147
|
-
logger.error(
|
|
147
|
+
logger.error("Error mode.")
|
|
148
148
|
return mode, data_list
|
|
149
149
|
|
|
150
150
|
|
|
151
|
+
def transform_special_string_into_float(data_frame):
|
|
152
|
+
data_frame[data_frame == "null"] = '0'
|
|
153
|
+
data_frame[data_frame == "False"] = '0'
|
|
154
|
+
data_frame[data_frame == "True"] = '1'
|
|
155
|
+
|
|
156
|
+
|
|
151
157
|
class GraphMSComparator:
|
|
152
158
|
def __init__(self, input_param, output_path):
|
|
153
159
|
self.output_path = output_path
|
|
@@ -187,7 +193,6 @@ class GraphMSComparator:
|
|
|
187
193
|
result_dict[CompareConst.ERROR_MESSAGE] = error_message
|
|
188
194
|
|
|
189
195
|
if not error_flag:
|
|
190
|
-
n_value, b_value = reshape_value(n_value, b_value)
|
|
191
196
|
result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
|
|
192
197
|
result_dict[CompareConst.COSINE] = result_list[0]
|
|
193
198
|
result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
|
|
@@ -334,13 +339,17 @@ class GraphMSComparator:
|
|
|
334
339
|
CompareConst.BENCH_NORM])
|
|
335
340
|
|
|
336
341
|
npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
337
|
-
|
|
342
|
+
npu_float_data_df = npu_data_df[npu_float_type].astype(str)
|
|
343
|
+
transform_special_string_into_float(npu_float_data_df)
|
|
344
|
+
npu_data_df[npu_float_type] = npu_float_data_df.astype(float)
|
|
338
345
|
|
|
339
346
|
bench_float_type = [
|
|
340
347
|
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
341
348
|
CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
|
|
342
349
|
]
|
|
343
|
-
|
|
350
|
+
bench_float_data_df = bench_data_df[bench_float_type].astype(str)
|
|
351
|
+
transform_special_string_into_float(bench_float_data_df)
|
|
352
|
+
bench_data_df[bench_float_type] = bench_float_data_df.astype(float)
|
|
344
353
|
|
|
345
354
|
npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
346
355
|
bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
@@ -39,6 +39,7 @@ class DebuggerConfig:
|
|
|
39
39
|
self.check_mode = task_config.check_mode
|
|
40
40
|
self.framework = Const.MS_FRAMEWORK
|
|
41
41
|
self.summary_mode = task_config.summary_mode
|
|
42
|
+
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
42
43
|
self.check()
|
|
43
44
|
create_directory(self.dump_path)
|
|
44
45
|
|
|
@@ -69,4 +70,6 @@ class DebuggerConfig:
|
|
|
69
70
|
self.file_format = "npy"
|
|
70
71
|
if not self.check_mode:
|
|
71
72
|
self.check_mode = "all"
|
|
73
|
+
if not isinstance(self.async_dump, bool):
|
|
74
|
+
raise Exception("The parameters async_dump should be bool.")
|
|
72
75
|
return True
|