mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
import glob
|
|
19
|
+
from typing import Dict, List
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import Const
|
|
25
|
+
from msprobe.core.common.file_utils import (
|
|
26
|
+
check_file_or_directory_path,
|
|
27
|
+
FileOpen,
|
|
28
|
+
create_directory,
|
|
29
|
+
write_csv,
|
|
30
|
+
check_path_before_create,
|
|
31
|
+
read_csv,
|
|
32
|
+
write_df_to_csv
|
|
33
|
+
)
|
|
34
|
+
from msprobe.mindspore.code_mapping.graph import GraphNode
|
|
35
|
+
from msprobe.mindspore.common.log import logger
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# 定义Trie节点
|
|
39
|
+
class TrieNode:
    """One node of a character trie.

    The constructor sets three attributes: ``children`` (next character ->
    child node), ``is_end_of_key`` (True once an inserted key ends at this
    node) and ``value`` (the payload stored for that key).
    """

    def __init__(self):
        # No payload until a key is inserted that terminates here.
        self.value = None
        self.is_end_of_key = False
        # Maps a single character to the TrieNode reached by following it.
        self.children = {}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# 定义Trie树
|
|
47
|
+
class Trie:
    """Character trie used to find every stored key that occurs inside a string."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        """Store *value* under *key*, creating intermediate nodes as needed."""
        current = self.root
        for symbol in key:
            child = current.children.get(symbol)
            if child is None:
                child = TrieNode()
                current.children[symbol] = child
            current = child
        # Mark the node where the key terminates and attach its payload.
        current.is_end_of_key = True
        current.value = value

    def search_in_string(self, string):
        """Return the values of all stored keys that appear as substrings of *string*.

        Every start offset is tried in turn; from each offset the trie is
        walked character by character and a value is recorded each time a
        key end is passed, so overlapping and nested keys are all reported.
        """
        hits = []
        total = len(string)
        for start in range(total):
            current = self.root
            for offset in range(start, total):
                current = current.children.get(string[offset])
                if current is None:
                    # No stored key continues with this character.
                    break
                if current.is_end_of_key:
                    hits.append(current.value)
        return hits
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# 定义匹配函数
|
|
78
|
+
def match_codes(trie, name):
    """Join the code info of every trie entry whose key occurs in *name*.

    Each matched node's ``code_info`` lines are joined with
    ``Const.NEW_LINE``; the per-node results are joined the same way.
    """
    code_blocks = []
    for hit in trie.search_in_string(name):
        code_blocks.append(Const.NEW_LINE.join(hit.code_info))
    return Const.NEW_LINE.join(code_blocks)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def match_names(trie, name):
    """Join the ``scope`` of every trie entry whose key occurs in *name*.

    The scopes are joined with ``Const.NEW_LINE`` in match order.
    """
    scopes = []
    for hit in trie.search_in_string(name):
        scopes.append(hit.scope)
    return Const.NEW_LINE.join(scopes)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def map_op_names_to_codes_and_scopes(df, match_dict):
    """Add code-stack and scope-name columns to *df* based on its op-name column.

    A trie is built from *match_dict* (key -> node); every value in the
    ``Const.OP_NAME`` column is then searched for those keys. The frame is
    modified in place and also returned.
    """
    lookup = Trie()
    for scope_key in match_dict:
        lookup.insert(scope_key, match_dict[scope_key])

    def _codes_for(op_name):
        return match_codes(lookup, op_name)

    def _scopes_for(op_name):
        return match_names(lookup, op_name)

    df[Const.CODE_STACK] = df[Const.OP_NAME].apply(_codes_for)
    df[Const.SCOPE_NAME] = df[Const.OP_NAME].apply(_scopes_for)
    return df
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def find_npy_files(npy_path):
    """
    Collect the .npy files found at the given path.

    Parameters:
        npy_path (str): Path to search; may be a single file or a directory.

    Returns:
        List[Path]: Resolved paths of the .npy files found. Empty when the
        path is neither a matching file nor a directory.
    """
    target = Path(npy_path)

    # A single file with the numpy suffix is returned on its own.
    if target.suffix == Const.NUMPY_SUFFIX and target.is_file():
        check_file_or_directory_path(target)
        return [target.resolve()]

    if not target.is_dir():
        logger.info(f"The specified path is neither an .npy file nor a directory: {npy_path}")
        return []

    # Directory: walk it recursively and validate each hit before keeping it.
    found = []
    for candidate in target.rglob(Const.NUMPY_PATTERN):
        check_file_or_directory_path(candidate)
        found.append(candidate.resolve())
    return found
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def write_to_csv(param: Dict, output_dir: str):
    """
    Write the binding result to a timestamped CSV file.

    Parameters:
        param (Dict): Data to write, shaped as {file name: (code stack, scope name)}.
        output_dir (str): Directory the CSV file is created in.
    """
    create_directory(output_dir)

    # The file name carries a second-resolution local timestamp.
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    file_path = Path(output_dir) / f"code_mapping_{stamp}.csv"
    check_path_before_create(file_path)

    rows = []
    for name, (code_stack, scope_name) in param.items():
        rows.append((name, code_stack, scope_name))
    header = [Const.FILE_PATH, Const.CODE_STACK, Const.SCOPE_NAME]
    write_df_to_csv(pd.DataFrame(rows, columns=header), file_path)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def find_statistic_files(path):
    """
    Locate ``statistic.csv`` files at or below *path*.

    Parameters:
        path (str): A file path or a directory path.

    Returns:
        List[str]: ``[path]`` when *path* is not a directory and its base name
        is ``statistic.csv``; an empty list for any other non-directory path;
        otherwise every ``statistic.csv`` found recursively under the directory.
    """
    if not os.path.isdir(path):
        if os.path.basename(path) == 'statistic.csv':
            return [path]
        return []

    # Escape the directory part so glob metacharacters in it ('[', '*', '?')
    # are matched literally instead of being read as wildcards.
    pattern = os.path.join(glob.escape(path), '**', "statistic.csv")
    return list(glob.glob(pattern, recursive=True))
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def check_and_fix_header(file_path: str):
    """
    Make sure the header line of a CSV file ends with a comma, appending one if missing.

    Parameters:
        file_path (str): Path of the CSV file.

    Returns:
        bool: True if the header was modified and the file rewritten; False if
        the file is empty or the header already ends with a comma.
    """
    with FileOpen(file_path, "r") as reader:
        lines = reader.readlines()

    if not lines:
        logger.warning(f"The file {file_path} is empty.")
        return False

    # Header text without its trailing line terminator.
    header = lines[0].rstrip(Const.NEW_LINE).rstrip('\r')

    if header.endswith(','):
        logger.info(f"The header already ends with a comma. No modification needed for the file: {file_path}.")
        return False

    logger.info(f"The header does not end with a comma. Adding a comma to the file: {file_path}.")
    # Re-terminate the header with the comma-plus-newline separator.
    lines[0] = header + Const.CSV_NEWLINE_SEPARATOR

    # Write the whole file back with the repaired header.
    with FileOpen(file_path, "w") as writer:
        writer.writelines(lines)
    logger.info(f"Added a trailing comma to the file: {file_path}.")
    return True
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def bind_for_statistic(statistic_files: List[str], match_dict: Dict):
    """
    Bind code information into each statistic file, rewriting the file in place.

    Parameters:
        statistic_files (List[str]): Paths of the statistic CSV files.
        match_dict (Dict): Key -> node mapping used to match op names.
    """
    for csv_path in statistic_files:
        # Repair the header first so the CSV parses with consistent columns.
        if check_and_fix_header(csv_path):
            logger.info(f"The header of the file {csv_path} has been fixed.")

        table = read_csv(csv_path, as_pd=True)

        # Add the code-stack and scope-name columns, then overwrite the file.
        table = map_op_names_to_codes_and_scopes(table, match_dict)
        write_df_to_csv(table, csv_path)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _lookup_mapped_name(directory, file_name, mapping_cache):
    """Return the name recorded for *file_name* in the directory's mapping.csv, or None."""
    if directory not in mapping_cache:
        csv_file_path = os.path.join(directory, 'mapping.csv')
        check_file_or_directory_path(csv_file_path)
        # Read each mapping.csv once, not once per numbered file.
        mapping_cache[directory] = read_csv(csv_file_path, header=None)
    df = mapping_cache[directory]
    # Column 0 is compared with the file name; column 1 supplies the mapped name.
    matching_row = df[df[0] == file_name]
    if matching_row.empty:
        return None
    return matching_row[1].values[0]


def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict[str, str]:
    """
    Bind code information to dump data found under *input_dir*.

    If .npy files are found, each one is matched against the scopes of the
    non-subgraph nodes and the result is returned. If none are found, any
    statistic.csv files are updated in place and an empty dict is returned.

    Parameters:
        input_dir (str): Dump data path (file or directory).
        nodes (Dict[str, GraphNode]): Parsed graph nodes.

    Returns:
        Dict: {resolved npy path: (code stack text, scope names text)}. Files
        whose name cannot be resolved or has no scope segment are skipped
        with a warning.
    """
    match_dict = {}
    for node in nodes.values():
        # Subgraph nodes are excluded from matching.
        if node.is_subgraph:
            continue
        # Normalize the scope name so it can be compared with dump names.
        scope_name = node.scope.replace(Const.SCOPE_SEPARATOR, Const.REPLACEMENT_CHARACTER)
        match_dict[scope_name] = node
    npy_files = find_npy_files(input_dir)

    bind_result = {}
    if not npy_files:
        statistic_files = find_statistic_files(input_dir)
        if statistic_files:
            bind_for_statistic(statistic_files, match_dict)
        return bind_result

    # The trie depends only on match_dict, so build it once for all files.
    trie = Trie()
    for key, value in match_dict.items():
        trie.insert(key, value)

    mapping_cache = {}
    for npy_file in npy_files:
        directory, file_name = os.path.split(npy_file)
        name_without_ext = os.path.splitext(file_name)[0]
        if name_without_ext.isdigit():
            # Purely numeric names are resolved through the directory's mapping.csv.
            corresponding_name = _lookup_mapped_name(directory, file_name, mapping_cache)
            if not isinstance(corresponding_name, str):
                # No usable entry: skip rather than fail in splitext().
                logger.warning(f"No mapping entry found for {npy_file} in mapping.csv, skip binding.")
                continue
            name_without_ext = os.path.splitext(corresponding_name)[0]

        name_parts = name_without_ext.split(".")
        if len(name_parts) < 2:
            # The scope is the second dot-separated segment; without it nothing can match.
            logger.warning(f"Cannot extract a scope from the file name of {npy_file}, skip binding.")
            continue
        node_scope = name_parts[1]

        npy_path = os.path.realpath(npy_file)
        bind_code = match_codes(trie, node_scope)
        bind_name = match_names(trie, node_scope)
        bind_result[npy_path] = (bind_code, bind_name)
    return bind_result
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory, FileChecker
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def add_ir_parser_arguments(parser):
    """Register the ``--ir``, ``--dump_data`` and ``--output`` string options on *parser*."""
    option_specs = (
        ('--ir', {'required': True, 'help': "Path to the graph file"}),
        ('--dump_data', {'required': True, 'default': None, 'help': "Path to data dir"}),
        ('--output', {'required': False, 'default': "./", 'help': "Path to output dir"}),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, type=str, **settings)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def check_args(args):
    """Make the ir, dump_data and output paths absolute, validate the inputs, create the output dir.

    *args* is updated in place with the absolute paths.
    """
    args.ir = os.path.abspath(args.ir)
    check_file_or_directory_path(args.ir)

    # dump_data may be a file or a directory; validate it as what it is.
    args.dump_data = os.path.abspath(args.dump_data)
    dump_is_dir = os.path.isdir(args.dump_data)
    check_file_or_directory_path(args.dump_data, isdir=dump_is_dir)

    args.output = os.path.abspath(args.output)
    create_directory(args.output)
|
|
40
|
+
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import List, Dict, Union
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GraphNode:
    """Node of a parsed graph: either an operator node or a (sub)graph.

    ``var_inputs`` and ``code_info`` default to empty lists. ``attrs``
    defaults to an empty list for subgraphs and an empty dict otherwise.
    """

    def __init__(self, name: str, pos: int = -1, unique_name: str = "", operator_name: str = "",
                 return_variable: str = "", return_value: str = "",
                 var_inputs: List[str] = None, has_constant_input: bool = False,
                 unique_id: str = "", scope: str = "", code_info: List[str] = None,
                 is_subgraph: bool = False, attrs: Union[Dict[str, str], List[str]] = None):
        self.name = name
        self.unique_name = unique_name
        self.pos = pos
        self.operator_name = operator_name
        self.return_variable = return_variable
        self.return_value = return_value
        self.var_inputs = var_inputs if var_inputs else []
        self.has_constant_input = has_constant_input
        self.unique_id = unique_id
        self.scope = scope
        self.code_info = code_info if code_info else []
        # Subgraphs keep attributes as a list of lines, other nodes as a dict.
        self.attrs = attrs if attrs else ({} if not is_subgraph else [])
        self.nodes = {}  # Internal nodes if this is a subgraph
        self.predecessors = []  # Predecessor nodes
        self.successors = []  # Successor nodes
        self.is_subgraph = is_subgraph

    def trace_back_ancestors(self, ancestors: List[str], visited: Dict[str, bool], parser) -> None:
        """Append this node and all its transitive predecessors to *ancestors*.

        Nodes are visited depth-first in pre-order and each one is recorded
        once. *visited* is indexed by ``unique_name`` and updated in place.
        The walk uses an explicit stack, so a very long predecessor chain
        cannot exceed the interpreter's recursion limit. *parser* is accepted
        for interface compatibility and is not used.
        """
        pending = [self]
        while pending:
            node = pending.pop()
            if visited[node.unique_name]:
                continue
            visited[node.unique_name] = True
            ancestors.append(node.unique_name)
            # Push in reverse so the first predecessor is expanded first,
            # matching the order of a recursive walk.
            pending.extend(reversed(node.predecessors))
|
|
49
|
+
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Tuple, List, Dict
|
|
19
|
+
from msprobe.mindspore.code_mapping.graph import GraphNode
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Parser:
|
|
23
|
+
def __init__(self):
    """Start with empty lookup tables; the parse methods fill them in."""
    # Filled by parse_nodes: unique_id -> node.
    self.number_dict = {}
    # Filled by parse_nodes: series number (e.g. "%1") -> unique node name.
    self.local_dict = {}
    # Filled by parse_func_graph: graph name -> node.
    self.nodes = {}
|
|
27
|
+
|
|
28
|
+
@staticmethod
|
|
29
|
+
def parse_subgraph_attributes(text: str, subgraph_node: GraphNode, start_pos: int, end_pos: int) -> None:
    """Append the lines following ``subgraph attr:`` to the node's attribute list.

    Only ``text[start_pos:end_pos]`` is searched. Nothing happens when the
    marker is absent or when the node's ``attrs`` is not a list.
    """
    found = re.compile(r'subgraph attr:\s*(.*)', re.DOTALL).search(text, start_pos, end_pos)
    if found is None:
        return
    if not isinstance(subgraph_node.attrs, list):
        return
    subgraph_node.attrs.extend(found.group(1).strip().split('\n'))
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
    """Read the ``key: value`` lines after ``# Attrs:`` into the node's attribute dict.

    The search starts at ``graph_node.pos``. Parsing stops at the first blank
    line. Each line is split on its first colon only, so a value may itself
    contain colons; lines without a colon are ignored. Nothing happens when
    the marker is absent or when the node's ``attrs`` is not a dict.
    """
    attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
    match = attr_pattern.search(text, graph_node.pos)
    if not match or not isinstance(graph_node.attrs, dict):
        return
    for attr in match.group(1).strip().split('\n'):
        if not attr.strip():
            # A blank line ends the attribute section.
            break
        key, separator, value = attr.partition(':')
        if not separator:
            # Not a key/value line; skip it instead of failing to unpack.
            continue
        graph_node.attrs[key.strip()] = value.strip()
|
|
49
|
+
|
|
50
|
+
@staticmethod
|
|
51
|
+
def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
    """Collect the consecutive ``# ...`` comment lines that follow *start_pos*.

    The scanned region is ``text[start_pos + 1:end_pos]``; a falsy *end_pos*
    means ``len(text) - 1``. Scanning stops at the first line that contains
    no ``# `` comment. Each collected comment has leading/trailing ``#`` and
    space characters removed, then leading/trailing ``/`` characters removed.
    """
    comment_re = re.compile(r'# .*', re.MULTILINE)
    if end_pos:
        stop = end_pos
    else:
        stop = len(text) - 1

    collected = []
    for raw_line in text[start_pos + 1:stop].split('\n'):
        found = comment_re.search(raw_line)
        if found is None:
            break
        collected.append(found.group(0).strip('# ').strip('/'))
    return collected
|
|
62
|
+
|
|
63
|
+
@staticmethod
|
|
64
|
+
def extract_bracket_content(text: str, start_pos: int) -> Tuple[str, int]:
    """Return the text from *start_pos* through the parenthesis that closes the group.

    Returns:
        Tuple[str, int]: The substring ``text[start_pos:i + 1]`` and the index
        ``i`` of the ``)`` that brings the nesting depth back to zero.

    Raises:
        ValueError: If a ``)`` appears with no open ``(``, or the text ends
        before the parentheses balance.
    """
    depth = 0
    for i in range(start_pos, len(text)):
        char = text[i]
        if char == '(':
            depth += 1
        elif char == ')':
            if depth == 0:
                # Stray closing parenthesis: report it as a mismatch rather
                # than failing on an empty stack.
                raise ValueError("Mismatched parentheses")
            depth -= 1
            if depth == 0:
                return text[start_pos:i + 1], i
    raise ValueError("Mismatched parentheses")
|
|
78
|
+
|
|
79
|
+
@staticmethod
|
|
80
|
+
def find_matching_brace(text: str, start_pos: int) -> int:
    """Return the index of the ``}`` that closes the brace group starting at or after *start_pos*.

    Raises:
        ValueError: If a ``}`` appears with no open ``{``, or the text ends
        before the braces balance.
    """
    depth = 0
    for i in range(start_pos, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            if depth == 0:
                # Stray closing brace: raise the function's own error rather
                # than failing on an empty stack.
                raise ValueError("Matching closing brace not found")
            depth -= 1
            if depth == 0:
                return i
    raise ValueError("Matching closing brace not found")
|
|
90
|
+
|
|
91
|
+
@staticmethod
|
|
92
|
+
def extract_constants(inputs_str: str) -> List[str]:
    """Return every ``Name(...)`` token in *inputs_str*, matched non-greedily up to the first ``)``."""
    return re.findall(r'\b(\w+\(.*?\))', inputs_str)
|
|
96
|
+
|
|
97
|
+
def parse_func_graph(self, text: str) -> None:
    """Register a GraphNode in ``self.nodes`` for every ``# IR entry: @<name>`` line in *text*.

    Each node records the entry name and the offset where the marker starts.
    """
    for entry in re.finditer(r'# IR entry: @(\S+)', text):
        graph_name = entry.group(1)
        self.nodes[graph_name] = GraphNode(name=graph_name, pos=entry.start(), is_subgraph=False)
|
|
104
|
+
|
|
105
|
+
def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None:
    """Parse every ``%N(name) = Operator(...)`` statement in ``text`` into a GraphNode.

    For each statement this method:
      * records ``%N -> "%N&name"`` in ``self.local_dict``;
      * collects ``%...`` / ``@...`` references and ``name(...)`` constants
        from the operator's argument list;
      * takes the scope from the first ``# ...scope...: (...)`` comment found
        after the argument list, and a unique id from a
        ``cnode_primal_attrs`` comment located before that scope comment;
      * attaches code info (via ``parse_code_info``) when a scope was found;
      * stores the node in ``self.nodes`` (first registration under a unique
        name wins), in ``self.number_dict`` when it has an id and a scope not
        starting with "Gradients", and in ``subgraph_info.nodes`` when a
        subgraph is given;
      * links predecessors/successors, creating "Constant" and "Param"
        placeholder nodes for inputs that are not yet known.

    Args:
        text: IR text to parse (the whole file or a single subgraph body).
        subgraph_info: Subgraph node that owns the parsed statements, or a
            falsy value (``parse`` passes ``None``) for the top-level pass.

    Raises:
        ValueError: Propagated from ``extract_bracket_content`` when an
            operator's argument list is not balanced.
    """
    # Loop-invariant patterns are compiled once instead of once per node.
    node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(')
    scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE)
    id_pattern = re.compile(r'.*cnode_primal_attrs:'
                            r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE)

    matches = list(node_pattern.finditer(text))
    for i, match in enumerate(matches):
        series_number = match.group(1)
        variable_name = match.group(2)
        operator_name = match.group(3)
        unique_name = "&".join([series_number, variable_name])
        self.local_dict[series_number] = unique_name

        # match.end() - 1 is the '(' that opens the operator's argument list.
        args_str, end_pos = self.__class__.extract_bracket_content(text, match.end() - 1)
        inputs = re.findall(r'%\w+', args_str)
        inputs += re.findall(r'@\w+', args_str)

        constants = self.__class__.extract_constants(args_str)

        scope_match = scope_pattern.search(text, end_pos)
        scope = scope_match.group(1) if scope_match else ""

        # The id comment is looked for between the argument list and the scope
        # comment. Without a scope comment there is no such upper bound (the
        # old code crashed on ``None.start()``), so stop at the next statement
        # instead, or at the end of the text for the last one.
        if scope_match:
            id_search_end = scope_match.start()
        elif i < len(matches) - 1:
            id_search_end = matches[i + 1].start()
        else:
            id_search_end = len(text)
        unique_id_match = id_pattern.search(text, end_pos, id_search_end)
        unique_id = unique_id_match.group(1) if unique_id_match else None

        if scope:
            # A non-empty scope implies scope_match is not None.
            next_match = matches[i + 1].start() - 1 if i < len(matches) - 1 else None
            code_info = self.__class__.parse_code_info(text, scope_match.end(), next_match)
        else:
            code_info = None

        node_info = GraphNode(name=variable_name, unique_name=unique_name, operator_name=operator_name,
                              var_inputs=inputs + constants, unique_id=unique_id, scope=scope, code_info=code_info)

        if unique_id and scope and not scope.startswith("Gradients"):
            self.number_dict[unique_id] = node_info

        if subgraph_info:
            subgraph_info.nodes[variable_name] = node_info

        # Keep the first node registered under this unique name.
        if not self.nodes.get(unique_name, None):
            self.nodes[unique_name] = node_info

        for const in constants:
            if const not in self.nodes:
                const_node = GraphNode(name=const, operator_name="Constant", var_inputs=[], has_constant_input=True)
                # The old inner check looked the node *object* up as a key,
                # which the membership test above already makes redundant.
                self.nodes[const] = const_node
                if subgraph_info:
                    subgraph_info.nodes[const] = const_node
                self.local_dict[const] = const

        for input_var in node_info.var_inputs:
            if input_var in self.local_dict or input_var in self.nodes:
                input_name = self.local_dict.get(input_var, input_var)
                input_node = self.nodes.get(input_name, None)
                if input_node:
                    node_info.predecessors.append(input_node)
                    input_node.successors.append(node_info)
            else:
                # Unknown input: link a fresh "Param" placeholder, registering
                # it only if no node already holds that name.
                param_node = GraphNode(name=input_var, operator_name="Param", var_inputs=[],
                                       has_constant_input=False)
                if not self.nodes.get(input_var, None):
                    self.nodes[input_var] = param_node
                node_info.predecessors.append(param_node)
                param_node.successors.append(node_info)
|
|
175
|
+
|
|
176
|
+
def extract_callees(self, text: str) -> None:
    """Append ``Partial(@callee(`` targets to each node's ``var_inputs``.

    For every node in ``self.nodes`` the text from ``node.pos`` up to (not
    including) the next '}' is scanned; each captured callee name that is not
    already in ``node.var_inputs`` is appended.

    Args:
        text: The IR text that the nodes' ``pos`` offsets refer to.
    """
    # Compiled once; the pattern does not depend on the node being scanned.
    callee_pattern = re.compile(r'Partial\(@(\S+)\(')
    for node_info in self.nodes.values():
        func_start_pos = node_info.pos
        func_end_pos = text.find('}', func_start_pos)
        if func_end_pos == -1:
            # str.find reports "not found" as -1; used as a slice end that
            # would silently drop the final character, so scan to the end.
            func_end_pos = len(text)
        func_text = text[func_start_pos:func_end_pos]
        for callee_match in callee_pattern.finditer(func_text):
            callee_name = callee_match.group(1)
            if callee_name not in node_info.var_inputs:
                node_info.var_inputs.append(callee_name)
|
|
187
|
+
|
|
188
|
+
def parse_subgraphs(self, text: str) -> None:
    """Register every ``subgraph @name ... {`` block and parse the nodes inside it.

    For each subgraph header found in ``text`` this method creates a
    GraphNode (``is_subgraph=True``, ``pos`` = header start), stores it in
    ``self.nodes`` under the subgraph name, passes the span between the
    previous subgraph's end and this header to ``parse_subgraph_attributes``,
    parses the block body with ``parse_nodes`` and finally records the index
    of the closing brace in ``subgraph_info.end``.

    Args:
        text: The IR text to scan.

    Raises:
        ValueError: Propagated from ``find_matching_brace`` when a subgraph
            block is not closed.
    """
    subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{')
    end_pos = 0
    for match in subgraph_pattern.finditer(text):
        # Attribute text starts two characters past the previous block's
        # closing brace (offset 2 for the first subgraph).
        last_pos = end_pos + 2
        # \S+ may swallow a parameter list, so keep only the part before '('.
        subgraph_name = match.group(1).split('(')[0]
        start_pos = match.start()
        end_pos = self.__class__.find_matching_brace(text, start_pos)
        subgraph_text = text[start_pos:end_pos + 1]
        subgraph_info = GraphNode(name=subgraph_name, pos=start_pos, is_subgraph=True)
        self.nodes[subgraph_name] = subgraph_info
        self.__class__.parse_subgraph_attributes(text, subgraph_info, last_pos, start_pos)
        self.parse_nodes(subgraph_text, subgraph_info)
        subgraph_info.end = end_pos
        logging.info('Parsed subgraph: %s', subgraph_name)
|
|
205
|
+
|
|
206
|
+
def count_nodes(self) -> Tuple[int, int]:
    """Count the registered nodes.

    Returns:
        ``(total, cnodes)``: the number of entries in ``self.nodes`` and how
        many of them have a ``name`` beginning with ``'CNode'``.
    """
    cnode_total = 0
    for graph_node in self.nodes.values():
        if graph_node.name.startswith('CNode'):
            cnode_total += 1
    return len(self.nodes), cnode_total
|
|
210
|
+
|
|
211
|
+
def create_backward_map(self):
    """Copy code info onto "Gradients" nodes from the node sharing their unique id.

    For each node in ``self.nodes`` whose scope starts with ``"Gradients"``,
    the node stored in ``self.number_dict`` under the same ``unique_id`` (if
    any) supplies the ``code_info``. All other nodes are left untouched.
    """
    for graph_node in self.nodes.values():
        node_scope = graph_node.scope
        if not node_scope or not node_scope.startswith("Gradients"):
            continue
        forward_node = self.number_dict.get(graph_node.unique_id, None)
        if not forward_node:
            continue
        graph_node.code_info = forward_node.code_info
|
|
217
|
+
|
|
218
|
+
def parse(self, text: str) -> None:
    """Run all parsing passes over ``text`` in their fixed order.

    Args:
        text: The full IR text to parse.
    """
    # Graph-level passes: IR entry markers first, then subgraph blocks.
    self.parse_func_graph(text)
    self.parse_subgraphs(text)
    # Whole-text node pass with no owning subgraph.
    self.parse_nodes(text, None)
    # Post-processing passes that work on the nodes collected above.
    self.extract_callees(text)
    self.create_backward_map()
|
|
224
|
+
|
|
225
|
+
def get_nodes(self) -> Dict[str, GraphNode]:
    """Return the parser's node mapping (the ``self.nodes`` object itself, not a copy)."""
    parsed_nodes = self.nodes
    return parsed_nodes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.code_mapping.processor import process
|
|
17
|
+
from msprobe.mindspore.code_mapping.cmd_parser import check_args
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def code_mapping_main(args):
    """Entry point for code mapping: check the arguments, then process them.

    Args:
        args: Argument object handed unchanged to ``check_args`` and then to
            ``process``.
    """
    # Validation runs first; process is only reached if check_args returns.
    check_args(args)
    process(args)
|
|
23
|
+
|
|
24
|
+
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.code_mapping.graph_parser import Parser
|
|
17
|
+
from msprobe.mindspore.code_mapping.bind import bind_code_info_for_data, write_to_csv
|
|
18
|
+
from msprobe.core.common.file_utils import FileOpen
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def process(args):
    """Parse an IR file, bind code info to dump data and write the result as CSV.

    Args:
        args: Object providing ``ir`` (path of the IR file to read),
            ``dump_data`` (passed to ``bind_code_info_for_data``) and
            ``output`` (passed to ``write_to_csv``).
    """
    with FileOpen(args.ir, 'r') as ir_file:
        ir_text = ir_file.read()

    graph_parser = Parser()
    graph_parser.parse(ir_text)

    binding = bind_code_info_for_data(args.dump_data, graph_parser.get_nodes())
    if not binding:
        # Nothing was bound, so no CSV is written.
        return
    write_to_csv(binding, args.output)
|
|
34
|
+
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -48,6 +48,8 @@ class Const:
|
|
|
48
48
|
MINT_DATA_PREFIX = "Mint."
|
|
49
49
|
MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."
|
|
50
50
|
DISTRIBUTED_DATA_PREFIX = "Distributed."
|
|
51
|
+
TORCH_DATA_PREFIX = "Torch."
|
|
52
|
+
TORCH_NPU_DATA_PREFIX = "NPU."
|
|
51
53
|
|
|
52
54
|
SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
|
|
53
55
|
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|