mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
from msprobe.core.common.log import logger
|
|
5
|
+
from msprobe.core.common.utils import CompareException
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Trie:
    """Prefix tree over dot-separated dump node names (e.g. ``Cell.a.b.Type.0``)."""

    def __init__(self, type_name=None, has_data=False):
        self.type_name = type_name     # class/type name of the layer at this node
        self.call_count_list = []      # call counts recorded for completed inserts
        self.children = {}             # child path segment -> Trie
        self.has_data = has_data       # True once an insert() terminated at this node
        self.node_type = None          # e.g. "Cell" or "func"; set by insert()

    def __repr__(self):
        return (f"Node(type_name={self.type_name}, "
                f"has_data={self.has_data}, call number={len(self.call_count_list)})")

    def insert(self, word, word_type="func"):
        """Insert a dump name into the trie.

        Layout of ``word``: ``<prefix path>.<type_name>.<call_count>``, e.g.
        ``Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1``
        where the prefix path is everything up to ``out_proj``, the type name is
        ``RowParallelLinear`` and the call count is ``1``.
        """
        segments = word.split(Const.SEP)
        if len(segments) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)

        leaf_type = segments[-2]
        leaf_count = segments[-1]

        cursor = self
        for segment in segments[:-2]:
            cursor = cursor.children.setdefault(segment, Trie())
            # Intermediate nodes default their type to their own path segment.
            if cursor.type_name is None:
                cursor.type_name = segment

        cursor.type_name = leaf_type
        cursor.has_data = True
        cursor.call_count_list.append(leaf_count)
        cursor.node_type = word_type
|
50
|
+
|
|
51
|
+
class DFSConverter:
    """Depth-first walker that accumulates {original dump name: mapped dump name}."""

    def __init__(self, mapping, max_depth=100):
        self.mapping = mapping        # {type_name: {child name: renamed child name}}
        self.max_depth = max_depth    # recursion guard
        self.result = {}              # accumulated name translation

    def traverse_and_collect(self, node, path="", mapping_path="", depth=0):
        """Walk the trie from *node*, returning the accumulated result dict.

        Raises:
            CompareException: RECURSION_LIMIT_ERROR when the tree is deeper
                than ``max_depth``.
        """
        if depth > self.max_depth:
            logger.error("The converted data depth is too large, please check the data")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

        if node is None:
            return self.result

        cls_name = node.type_name
        if node.has_data:
            # "Cell" nodes already carry their type in the path; other node
            # types need the type name appended before the call count.
            for call_count in node.call_count_list:
                if node.node_type == "Cell":
                    src = f"{path}.{call_count}"
                    dst = f"{mapping_path}.{call_count}"
                else:
                    src = f"{path}.{cls_name}.{call_count}"
                    dst = f"{mapping_path}.{cls_name}.{call_count}"
                self.result[src] = dst

        renames = self.mapping.get(cls_name, {})

        for child_name, child in node.children.items():
            child_path = f"{path}.{child_name}" if path else child_name
            mapped_child = renames.get(child_name, child_name)
            child_mapping_path = f"{mapping_path}.{mapped_child}" if mapping_path else mapped_child
            self.traverse_and_collect(child, child_path, child_mapping_path, depth + 1)

        return self.result
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_mapping_list(ms_tree, mapping):
    """Flatten the MindSpore trie into (ms_name, pt_name) pairs.

    The mapped name's leading "Cell" is rewritten to "Module" so it matches
    the PyTorch dump naming scheme.
    """
    name_pairs = DFSConverter(mapping).traverse_and_collect(ms_tree)
    return [(ms_name, re.sub(r"^Cell", "Module", pt_name))
            for ms_name, pt_name in name_pairs.items()]
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_prefix_mapping(scope_list):
    """layer name to layer name.class_name

    Maps each dump entry's type-free prefix (path + call count, without the
    class-name segment) to its full name, e.g.
    ``Cell.a.embedding.3`` -> ``Cell.a.embedding.Embedding.3``.

    Entries whose origin data does not start with "Cell"/"Module" are skipped.

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when a key has fewer
            than two dot-separated parts.
    """
    layer_mapping = {}
    for name, v in scope_list.items():
        origin_data = v.get("origin_data")
        # Bug fix: guard against a missing "origin_data" field; previously a
        # None value raised AttributeError on .startswith().
        if not origin_data or not origin_data.startswith(("Cell", "Module")):
            continue
        name_list = name.split(Const.SEP)
        if len(name_list) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
        # Drop the class-name segment (second to last), keep the call count.
        prefix_name_list = name_list[:-2] + [name_list[-1]]
        prefix_name = Const.SEP.join(prefix_name_list)
        layer_mapping[prefix_name] = name
    return layer_mapping
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def get_layer_mapping(ms_scope_list, pt_scope_list, mapping):
    """Produce (mindspore name, pytorch name) pairs for layer-level comparison.

    Args:
        ms_scope_list: {rebuilt ms name: {"origin_data": ...}} entries.
        pt_scope_list: same structure for the pytorch dump.
        mapping: {class name: {ms child name: pt child name}} rename table.
    """
    # 1. layer prefix -> full name, e.g.
    # Cell.network_with_loss.language_model.embedding.3 ->
    # Cell.network_with_loss.language_model.embedding.Embedding.3
    ms_prefix2fullname = get_prefix_mapping(ms_scope_list)

    # 2. Build the trie of MindSpore names, tagging each with its top-level domain.
    ms_tree = Trie(type_name="Cell")
    for node_name, record in ms_scope_list.items():
        top_domain = record.get('origin_data').split(Const.SEP)[0]
        ms_tree.insert(node_name, top_domain)
    msname2ptname = get_mapping_list(ms_tree, mapping)

    # 3. Same prefix -> full-name table on the PyTorch side.
    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)

    final_mapping = []
    for ms_name, pt_name in msname2ptname:
        if ms_name in ms_prefix2fullname:
            # Cell/Module node: resolve both sides through the prefix tables.
            # NOTE(review): a missing pt prefix yields a None partner here,
            # unlike the func branch below which skips — confirm this
            # asymmetry is intentional.
            final_mapping.append((ms_prefix2fullname.get(ms_name),
                                  pt_prefix2fullname.get(pt_name, None)))
        elif ms_name in ms_scope_list:
            # Functional node: use origin data, dropping the trailing
            # forward/backward segment.
            ms_origin = ms_scope_list.get(ms_name)['origin_data']
            pt_record = pt_scope_list.get(pt_name, None)
            if not pt_record:
                continue
            pt_origin = pt_record['origin_data']
            final_mapping.append((Const.SEP.join(ms_origin.split(Const.SEP)[:-1]),
                                  Const.SEP.join(pt_origin.split(Const.SEP)[:-1])))
        else:
            final_mapping.append((ms_name, pt_name))

    return final_mapping
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from msprobe.core.common.const import Const
|
|
2
|
+
from msprobe.core.common.log import logger
|
|
3
|
+
|
|
4
|
+
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the (start, end) indices of the region delimited by the two signs.

    Returns -1 for an index whose sign was never seen.  ``start_pos`` keeps
    the LAST occurrence of ``start_sign`` before ``end_sign``; scanning stops
    at the first line containing ``end_sign`` (a line containing both signs
    counts as a start).
    """
    start_pos = -1
    end_pos = -1
    for position, text in enumerate(lines):
        if start_sign in text:
            start_pos = position
        elif end_sign in text:
            end_pos = position
            break
    return start_pos, end_pos
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def find_stack_func_list(lines):
    """Extract caller function names from stack-trace lines, outermost first.

    Lines whose file field or function field matches the configured skip
    lists are discarded; the collected names are returned reversed so the
    call chain reads top-down.
    """
    collected = []
    for entry in lines:
        fields = entry.split(',')
        # Skip frames from ignored files first, then from ignored functions.
        if any(token in fields[Const.STACK_FILE_INDEX] for token in Const.FILE_SKIP_LIST):
            continue
        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(token in func_field for token in Const.FUNC_SKIP_LIST):
            continue
        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])
    return collected[::-1]
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_duplicated_name(components):
    """Duplicate the op-name segment so it also serves as the type segment.

    e.g. ``Functional.add.0.forward`` becomes ``Functional.add.add.0.forward``
    (the duplication slot is ``Const.CONSTRUCT_NAME_INDEX``).  Inputs that are
    too short, or whose name slot is a digit, are returned unchanged with a
    warning.
    """
    if len(components) < 3 or components[Const.CONSTRUCT_NAME_INDEX].isdigit():
        logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
        return components
    split_at = Const.CONSTRUCT_NAME_INDEX
    return components[:split_at + 1] + components[split_at:]
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def modify_mapping_with_stack(stack, construct):
    """Rebuild comparable node names from construct.json entries and stack traces.

    Args:
        stack: {node name: list of stack-trace lines} from stack.json.
        construct: {node name: parent scope name or None} from construct.json.

    Returns:
        dict: {rebuilt name: {origin_data, scope, stack}}; empty when either
        input is empty.
    """
    if not stack or not construct:
        return {}

    # Is this a MindSpore dump? (MindSpore scope names contain "Cell".)
    is_ms = any("Cell" in ii for ii in construct)
    # The re-keyed mapping to return.
    final_pres = {}
    for key in construct:
        key_components = key.split(Const.SEP)
        code_list = stack.get(key, None)
        parent_node = construct.get(key, None)
        # Names that do not start with a standard prefix get normalized.
        if not key.startswith(("Module", "Cell")):
            if not parent_node:
                # No parent scope found: default the top-level domain.
                key_components[0] = "Cell" if is_ms else "Module"
                # Duplicate the node name as its type, e.g. add.add.
                duplicated_components = get_duplicated_name(key_components)
                modified_key = Const.SEP.join(duplicated_components)

                modified_key = modified_key.replace(".forward", "").replace(".backward", "")
                final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
                continue
            parent = parent_node.split(Const.SEP)
            if len(parent) < 4:
                logger.info(f"Parent name in construct.json is not valid")
                continue
            # The parent's node-name slot may be shifted by a digit segment.
            parent_idx = Const.NAME_FIRST_POSSIBLE_INDEX if not \
                parent[Const.NAME_FIRST_POSSIBLE_INDEX].isdigit() else Const.NAME_SECOND_POSSIBLE_INDEX
            parent_name = parent[parent_idx]

            if code_list:
                # {name}.Class.count_number.X ward Or {name}.Class.count_number.X ward.ele_number
                if parent_name.endswith('s'):
                    parent_name = parent_name[:-1]
                if len(key_components) < 3:
                    logger.info("The length of key in construct is less than 3, please check")
                    continue
                # {name}.count_number.X ward
                func_name = key_components[-3]
                start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)
                # NOTE(review): when either sign is missing this slices with
                # -1 (last element / empty) — confirm that is intended.
                regard_scope = code_list[start_pos:end_pos]

                func_stack_list = find_stack_func_list(regard_scope)
            else:
                func_stack_list = []
            # Compose: parent path (up to its node name) + call-stack names +
            # original key with its node name duplicated as the type segment.
            final_res_key = Const.SEP.join(parent[:parent_idx + 1] + func_stack_list +
                                           key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] +
                                           key_components[Const.CONSTRUCT_NAME_INDEX:])
            # Bug fix: the original used str.strip(".forward"), which removes
            # any trailing run of those CHARACTERS and mangled names ending in
            # e.g. 'd' or 'r'.  Remove the literal suffix markers instead,
            # matching the branch above.
            final_res_key = final_res_key.replace(".forward", "").replace(".backward", "")
        else:
            # Standard key: drop the type segment, keep the call count.
            final_res_key = Const.SEP.join(key_components[:-2] + [key_components[-1]])
            func_stack_list = []
        final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: parent_node,
                                     Const.STACK: Const.SEP.join(func_stack_list) if func_stack_list else None}
    return final_pres
|
@@ -1,29 +1,46 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import re
|
|
2
3
|
import copy
|
|
4
|
+
import sys
|
|
5
|
+
from itertools import zip_longest
|
|
6
|
+
|
|
3
7
|
from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
|
|
4
|
-
task_dumppath_get
|
|
5
|
-
from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
|
|
8
|
+
task_dumppath_get, struct_json_get, add_time_with_yaml
|
|
9
|
+
from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
|
|
6
10
|
from msprobe.core.common.const import Const, CompareConst
|
|
7
11
|
from msprobe.core.common.log import logger
|
|
8
12
|
from msprobe.core.common.exceptions import FileCheckException
|
|
9
13
|
from msprobe.core.compare.acc_compare import Comparator
|
|
10
14
|
from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
|
|
11
|
-
|
|
15
|
+
from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
|
|
16
|
+
from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
|
|
12
17
|
|
|
13
18
|
class MSComparator(Comparator):
|
|
14
|
-
def __init__(self, cell_mapping=None, api_mapping=None):
|
|
19
|
+
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
    """MindSpore-side comparator.

    Args:
        cell_mapping: cell-name mapping file path, or None.
        api_mapping: api-name mapping file path, or None.
        data_mapping: data-name mapping as a file path, a dict, or None.
        is_cross_framework: whether the comparison crosses frameworks;
            only honoured when data_mapping is supplied.

    Raises:
        TypeError: when data_mapping is neither dict, str nor None.
    """
    self.frame_name = MSComparator.__name__
    self.cell_mapping = cell_mapping
    self.api_mapping = api_mapping
    self.data_mapping = data_mapping

    # With an explicit data mapping the caller decides; otherwise any
    # cell/api mapping implies a cross-framework comparison.
    if data_mapping:
        self.cross_frame = is_cross_framework
    else:
        self.cross_frame = cell_mapping is not None or api_mapping is not None

    self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
    self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
    if api_mapping is not None:
        self.ms_to_pt_mapping = self.load_internal_api()

    if isinstance(self.data_mapping, str) or self.data_mapping is None:
        self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
    elif isinstance(self.data_mapping, dict):
        self.data_mapping_dict = self.data_mapping
    else:
        raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
                        f"{type(self.data_mapping)}")
|
24
41
|
def load_internal_api(self):
    """Load the built-in MindSpore->PyTorch API name table shipped next to this module."""
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return load_yaml(os.path.join(module_dir, "ms_to_pt_api.yaml"))
|
|
29
46
|
def load_mapping_file(self, mapping_file):
|
|
@@ -52,10 +69,12 @@ class MSComparator(Comparator):
|
|
|
52
69
|
if self.api_mapping is not None:
|
|
53
70
|
npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
|
|
54
71
|
if isinstance(self.api_mapping, str):
|
|
55
|
-
npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
|
|
72
|
+
npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
|
|
73
|
+
bench_dict_new)
|
|
56
74
|
if target_dict:
|
|
57
75
|
bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
|
|
58
|
-
npu_op_name
|
|
76
|
+
npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
|
|
77
|
+
bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
|
|
59
78
|
struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
|
|
60
79
|
if not fuzzy_match:
|
|
61
80
|
return npu_op_name == bench_op_name and struct_match
|
|
@@ -72,7 +91,7 @@ class MSComparator(Comparator):
|
|
|
72
91
|
if load_pt_file:
|
|
73
92
|
import torch
|
|
74
93
|
from msprobe.pytorch.common.utils import load_pt
|
|
75
|
-
data_value = load_pt(data_path).detach()
|
|
94
|
+
data_value = load_pt(data_path, True).detach()
|
|
76
95
|
if data_value.dtype == torch.bfloat16:
|
|
77
96
|
data_value = data_value.to(torch.float32)
|
|
78
97
|
data_value = data_value.numpy()
|
|
@@ -99,7 +118,7 @@ class MSComparator(Comparator):
|
|
|
99
118
|
elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
|
|
100
119
|
return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
|
|
101
120
|
else:
|
|
102
|
-
return npu_op_name
|
|
121
|
+
return npu_op_name
|
|
103
122
|
|
|
104
123
|
def remove_element(self, op_name, struct, summary, idx):
|
|
105
124
|
del op_name[idx]
|
|
@@ -107,7 +126,12 @@ class MSComparator(Comparator):
|
|
|
107
126
|
del summary[idx]
|
|
108
127
|
|
|
109
128
|
def get_api_name(self, api_list):
    """Return "<framework>.<op>" built from the first two dump-name segments.

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when *api_list* has
            fewer than two elements (malformed dump data).
    """
    try:
        return api_list[0] + Const.SEP + api_list[1]
    except IndexError as error:
        logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
111
135
|
|
|
112
136
|
def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
|
|
113
137
|
"""
|
|
@@ -119,10 +143,13 @@ class MSComparator(Comparator):
|
|
|
119
143
|
tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
|
|
120
144
|
"""
|
|
121
145
|
npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
|
|
122
|
-
npu_struct_in
|
|
123
|
-
|
|
146
|
+
npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
|
|
147
|
+
bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
|
|
148
|
+
npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
149
|
+
bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
124
150
|
npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
|
|
125
|
-
npu_in_len, bench_in_len
|
|
151
|
+
npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
|
|
152
|
+
npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
|
|
126
153
|
ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
|
|
127
154
|
ms_api_name = self.get_api_name(ms_api_list)
|
|
128
155
|
pt_api_name = self.get_api_name(pt_api_list)
|
|
@@ -130,22 +157,25 @@ class MSComparator(Comparator):
|
|
|
130
157
|
for api_dict in self.api_mapping_dict:
|
|
131
158
|
if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
|
|
132
159
|
ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
|
|
133
|
-
ms_user_output_len, pt_user_output_len
|
|
160
|
+
ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
|
|
134
161
|
if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
|
|
135
|
-
logger.warning("The user-defined mapping table is incorrect
|
|
162
|
+
logger.warning("The user-defined mapping table is incorrect,\
|
|
163
|
+
make sure that the number of parameters is equal")
|
|
136
164
|
break
|
|
137
165
|
ms_out_list = api_dict.get("ms_output", [])
|
|
138
166
|
for idx in reversed(range(npu_out_len)):
|
|
139
167
|
if idx not in ms_out_list:
|
|
140
168
|
del npu_struct_out[idx]
|
|
141
|
-
|
|
142
|
-
|
|
169
|
+
if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
|
|
170
|
+
del npu_summary[idx + npu_in_len]
|
|
171
|
+
del npu_op_name[idx + npu_in_len]
|
|
143
172
|
pt_out_list = api_dict.get("pt_output", [])
|
|
144
173
|
for idx in reversed(range(bench_out_len)):
|
|
145
174
|
if idx not in pt_out_list:
|
|
146
175
|
del bench_struct_out[idx]
|
|
147
|
-
|
|
148
|
-
|
|
176
|
+
if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
|
|
177
|
+
del bench_summary[idx + bench_in_len]
|
|
178
|
+
del bench_op_name[idx + bench_in_len]
|
|
149
179
|
ms_para_list = api_dict.get("ms_args", [])
|
|
150
180
|
for idx in reversed(range(npu_in_len)):
|
|
151
181
|
if idx not in ms_para_list:
|
|
@@ -159,8 +189,10 @@ class MSComparator(Comparator):
|
|
|
159
189
|
target_dict = api_dict
|
|
160
190
|
break
|
|
161
191
|
if target_dict:
|
|
162
|
-
new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
|
|
163
|
-
|
|
192
|
+
new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
|
|
193
|
+
CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
|
|
194
|
+
new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
195
|
+
CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
|
|
164
196
|
return new_npu_dict, new_bench_dict, target_dict
|
|
165
197
|
|
|
166
198
|
def para_sequence_update(self, npu_op_name, bench_op_name):
|
|
@@ -180,25 +212,115 @@ class MSComparator(Comparator):
|
|
|
180
212
|
if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
|
|
181
213
|
return del_bench_dict
|
|
182
214
|
ms_input_args_list = [i for i in range(npu_in_len)]
|
|
183
|
-
input_sub_list =list(set(ms_input_args_list) - set(ms_user_args_list))
|
|
215
|
+
input_sub_list = list(set(ms_input_args_list) - set(ms_user_args_list))
|
|
184
216
|
ms_output_args_list = [i for i in range(npu_out_len)]
|
|
185
|
-
output_sub_list =list(set(ms_output_args_list) - set(ms_user_output_list))
|
|
217
|
+
output_sub_list = list(set(ms_output_args_list) - set(ms_user_output_list))
|
|
186
218
|
bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
|
|
187
219
|
bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
|
|
188
220
|
bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
|
|
189
221
|
bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
|
|
190
222
|
for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
|
|
191
|
-
bench_op_name.insert(idx, CompareConst.
|
|
192
|
-
bench_struct_in.insert(idx, CompareConst.
|
|
193
|
-
bench_summary.insert(idx, CompareConst.
|
|
223
|
+
bench_op_name.insert(idx, CompareConst.N_A)
|
|
224
|
+
bench_struct_in.insert(idx, CompareConst.N_A)
|
|
225
|
+
bench_summary.insert(idx, CompareConst.N_A)
|
|
194
226
|
for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
|
|
195
|
-
bench_op_name.insert(npu_in_len + idx, CompareConst.
|
|
196
|
-
bench_struct_out.insert(idx, CompareConst.
|
|
197
|
-
bench_summary.insert(npu_in_len + idx, CompareConst.
|
|
198
|
-
del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
227
|
+
bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
|
|
228
|
+
bench_struct_out.insert(idx, CompareConst.N_A)
|
|
229
|
+
bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
|
|
230
|
+
del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
231
|
+
CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
|
|
199
232
|
return del_bench_dict
|
|
200
233
|
|
|
201
|
-
|
|
234
|
+
|
|
235
|
+
def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag):
|
|
236
|
+
def generate_execution_sequence(data):
|
|
237
|
+
sequence_map = {}
|
|
238
|
+
for index, item in enumerate(data.keys()):
|
|
239
|
+
if flag in item:
|
|
240
|
+
item_split = item.split(Const.SEP)
|
|
241
|
+
item_name = Const.SEP.join(item_split[0:-2])
|
|
242
|
+
item_index = item_split[-1]
|
|
243
|
+
if item_index == 'forward' or item_index == 'backward':
|
|
244
|
+
item_index = item_split[-2]
|
|
245
|
+
item_key = f"{item_name}.{item_index}"
|
|
246
|
+
sequence_map[item_key] = index
|
|
247
|
+
return sequence_map
|
|
248
|
+
|
|
249
|
+
npu_map = generate_execution_sequence(npu_data)
|
|
250
|
+
bench_map = generate_execution_sequence(bench_data)
|
|
251
|
+
|
|
252
|
+
def sort_by_map(item):
|
|
253
|
+
first_key = npu_map.get(item[0], sys.maxsize)
|
|
254
|
+
second_key = bench_map.get(item[1], sys.maxsize)
|
|
255
|
+
return first_key, second_key
|
|
256
|
+
|
|
257
|
+
return sorted(mapping_list, key=sort_by_map)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def generate_kernel_data(map_value, data, flag):
|
|
261
|
+
if not map_value:
|
|
262
|
+
return [], []
|
|
263
|
+
inputs_name = []
|
|
264
|
+
outputs_name = []
|
|
265
|
+
map_split = map_value.split(Const.SEP)
|
|
266
|
+
map_name = Const.SEP.join(map_split[0:-1])
|
|
267
|
+
map_index = map_split[-1]
|
|
268
|
+
for key, value in data.items():
|
|
269
|
+
if key.find(flag) != -1 and key.find(map_name) != -1:
|
|
270
|
+
if key.split(Const.SEP)[-1] != map_index and key.split(Const.SEP)[-2] != map_index :
|
|
271
|
+
continue
|
|
272
|
+
if flag == 'forward':
|
|
273
|
+
input_args = value.get('input_args', {})
|
|
274
|
+
else:
|
|
275
|
+
input_args = value.get('input', {})
|
|
276
|
+
output_args = value.get('output', {})
|
|
277
|
+
for i in range(len(input_args)):
|
|
278
|
+
inputs_name.append(f"{key}.input.{i}")
|
|
279
|
+
for i in range(len(output_args)):
|
|
280
|
+
outputs_name.append(f"{key}.output.{i}")
|
|
281
|
+
return inputs_name, outputs_name
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def generate_file_mapping(npu_json_path, bench_json_path, mapping_list):
|
|
285
|
+
|
|
286
|
+
npu_data = load_json(npu_json_path).get("data", {})
|
|
287
|
+
bench_data = load_json(bench_json_path).get("data", {})
|
|
288
|
+
|
|
289
|
+
forward_data = []
|
|
290
|
+
mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.FORWARD)
|
|
291
|
+
for map_value in mapping_list:
|
|
292
|
+
npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "forward")
|
|
293
|
+
bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "forward")
|
|
294
|
+
inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
|
|
295
|
+
outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
|
|
296
|
+
forward_data.extend(inputs_zip)
|
|
297
|
+
forward_data.extend(outputs_zip)
|
|
298
|
+
|
|
299
|
+
backward_data = []
|
|
300
|
+
mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.BACKWARD)
|
|
301
|
+
for map_value in mapping_list:
|
|
302
|
+
npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "backward")
|
|
303
|
+
bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward")
|
|
304
|
+
inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
|
|
305
|
+
outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
|
|
306
|
+
backward_data.extend(inputs_zip)
|
|
307
|
+
backward_data.extend(outputs_zip)
|
|
308
|
+
|
|
309
|
+
kernel_data = forward_data + backward_data
|
|
310
|
+
result = {key: value for key, value in kernel_data if key is not None}
|
|
311
|
+
|
|
312
|
+
return result
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def check_cross_framework(bench_json_path):
|
|
316
|
+
pattern = r'"data_name":\s*"[^"]+\.pt"'
|
|
317
|
+
with FileOpen(bench_json_path, 'r') as file:
|
|
318
|
+
for line in file:
|
|
319
|
+
if re.search(pattern, line):
|
|
320
|
+
return True
|
|
321
|
+
return False
|
|
322
|
+
|
|
323
|
+
|
|
202
324
|
def ms_compare(input_param, output_path, **kwargs):
|
|
203
325
|
try:
|
|
204
326
|
stack_mode = kwargs.get('stack_mode', False)
|
|
@@ -206,14 +328,30 @@ def ms_compare(input_param, output_path, **kwargs):
|
|
|
206
328
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
207
329
|
cell_mapping = kwargs.get('cell_mapping', None)
|
|
208
330
|
api_mapping = kwargs.get('api_mapping', None)
|
|
331
|
+
data_mapping = kwargs.get('data_mapping', None)
|
|
332
|
+
layer_mapping = kwargs.get('layer_mapping', None)
|
|
333
|
+
|
|
209
334
|
summary_compare, md5_compare = task_dumppath_get(input_param)
|
|
210
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
|
|
335
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
211
336
|
create_directory(output_path)
|
|
212
337
|
check_compare_param(input_param, output_path, summary_compare, md5_compare)
|
|
213
338
|
except (CompareException, FileCheckException) as error:
|
|
214
339
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
215
340
|
raise CompareException(error.code) from error
|
|
216
|
-
|
|
341
|
+
if layer_mapping:
|
|
342
|
+
pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK)
|
|
343
|
+
ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
|
|
344
|
+
mapping = load_yaml(layer_mapping)
|
|
345
|
+
ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
|
|
346
|
+
pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
|
|
347
|
+
layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping)
|
|
348
|
+
data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping)
|
|
349
|
+
|
|
350
|
+
data_mapping_name = add_time_with_yaml(f"data_mapping")
|
|
351
|
+
data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
|
|
352
|
+
save_yaml(data_mapping_path, data_mapping)
|
|
353
|
+
is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
|
|
354
|
+
ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
|
|
217
355
|
ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
218
356
|
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
|
|
219
357
|
md5_compare=md5_compare)
|
|
@@ -47,8 +47,10 @@ def npy_data_read(data_path, npy_file_list, mapping_dict):
|
|
|
47
47
|
def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
48
48
|
data_list = []
|
|
49
49
|
statistic_data_list = []
|
|
50
|
-
header_index = {
|
|
51
|
-
|
|
50
|
+
header_index = {
|
|
51
|
+
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
52
|
+
'Min Value': None,'Avg Value': None, 'L2Norm Value': None
|
|
53
|
+
}
|
|
52
54
|
for statistic_file in statistic_file_list:
|
|
53
55
|
with FileOpen(statistic_file, "r") as f:
|
|
54
56
|
csv_reader = csv.reader(f, delimiter=",")
|
|
@@ -65,8 +67,9 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
65
67
|
|
|
66
68
|
for data in statistic_data_list:
|
|
67
69
|
compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}"
|
|
70
|
+
op_name = f"{compare_key} {statistic_file_path}"
|
|
68
71
|
timestamp = int(data[4])
|
|
69
|
-
result_data = [
|
|
72
|
+
result_data = [op_name, compare_key, timestamp]
|
|
70
73
|
for key in header_index.keys():
|
|
71
74
|
if header_index[key] is None:
|
|
72
75
|
result_data.append(np.nan)
|
|
@@ -239,9 +242,20 @@ class GraphMSComparator:
|
|
|
239
242
|
compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
|
|
240
243
|
compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
|
|
241
244
|
check_path_before_create(compare_result_path)
|
|
245
|
+
self.to_excel(compare_result_df, compare_result_path)
|
|
246
|
+
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
247
|
+
|
|
248
|
+
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
|
|
249
|
+
size = len(compare_result_df)
|
|
250
|
+
# sheet size cannot be larger than 1048576
|
|
251
|
+
if size < CompareConst.MAX_EXCEL_LENGTH:
|
|
252
|
+
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if need_slice else compare_result_path
|
|
242
253
|
compare_result_df.to_excel(compare_result_path, index=False)
|
|
243
254
|
change_mode(compare_result_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
244
|
-
|
|
255
|
+
return slice_num + 1
|
|
256
|
+
else:
|
|
257
|
+
slice_num = self.to_excel(compare_result_df.iloc[0: size//2], compare_result_path, slice_num, True)
|
|
258
|
+
return self.to_excel(compare_result_df.iloc[size//2:], compare_result_path, slice_num, True)
|
|
245
259
|
|
|
246
260
|
def compare_process(self, rank_id, step_id):
|
|
247
261
|
# generate data_path
|
|
@@ -251,8 +265,8 @@ class GraphMSComparator:
|
|
|
251
265
|
return [], ''
|
|
252
266
|
|
|
253
267
|
# generate file name
|
|
254
|
-
npu_mode =
|
|
255
|
-
bench_mode =
|
|
268
|
+
npu_mode = GraphMode.ERROR_MODE
|
|
269
|
+
bench_mode = GraphMode.ERROR_MODE
|
|
256
270
|
npu_data_list = []
|
|
257
271
|
bench_data_list = []
|
|
258
272
|
for npu_data_path in npu_data_path_list:
|
|
@@ -262,7 +276,7 @@ class GraphMSComparator:
|
|
|
262
276
|
bench_mode, data_list = generate_data_name(bench_data_path)
|
|
263
277
|
bench_data_list.extend(data_list)
|
|
264
278
|
|
|
265
|
-
if npu_mode ==
|
|
279
|
+
if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
|
|
266
280
|
logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.")
|
|
267
281
|
return [], ''
|
|
268
282
|
if npu_mode != bench_mode:
|
|
@@ -286,11 +300,13 @@ class GraphMSComparator:
|
|
|
286
300
|
CompareConst.BENCH_NORM])
|
|
287
301
|
|
|
288
302
|
npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
289
|
-
npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(
|
|
303
|
+
npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
|
|
290
304
|
|
|
291
|
-
bench_float_type = [
|
|
292
|
-
|
|
293
|
-
|
|
305
|
+
bench_float_type = [
|
|
306
|
+
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
307
|
+
CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
|
|
308
|
+
]
|
|
309
|
+
bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
|
|
294
310
|
|
|
295
311
|
npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
296
312
|
bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|