mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import time
18
+ import glob
19
+ from typing import Dict, List
20
+ from pathlib import Path
21
+
22
+ import pandas as pd
23
+
24
+ from msprobe.core.common.const import Const
25
+ from msprobe.core.common.file_utils import (
26
+ check_file_or_directory_path,
27
+ FileOpen,
28
+ create_directory,
29
+ write_csv,
30
+ check_path_before_create,
31
+ read_csv,
32
+ write_df_to_csv
33
+ )
34
+ from msprobe.mindspore.code_mapping.graph import GraphNode
35
+ from msprobe.mindspore.common.log import logger
36
+
37
+
38
+ # Trie node definition
39
class TrieNode:
    """One node of the character trie used to look up scope names."""

    def __init__(self):
        # Payload stored by Trie.insert on the node where a key ends.
        self.value = None
        # True only on nodes where an inserted key terminates.
        self.is_end_of_key = False
        # Maps the next key element to the child TrieNode.
        self.children = {}
44
+
45
+
46
+ # Trie definition
47
class Trie:
    """Character trie that maps inserted keys to arbitrary values."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        """Store ``value`` under ``key``, creating intermediate nodes as needed."""
        cursor = self.root
        for symbol in key:
            child = cursor.children.get(symbol)
            if child is None:
                child = TrieNode()
                cursor.children[symbol] = child
            cursor = child
        # The node reached after the last element marks the end of the key.
        cursor.is_end_of_key = True
        cursor.value = value

    def search_in_string(self, string):
        """Return the values of every inserted key that occurs inside ``string``.

        Every start offset is tried in order; for each offset the trie is
        walked as far as the text allows and the value of each key that ends
        along the way is collected, so overlapping and nested keys are all
        reported.
        """
        hits = []
        text_length = len(string)
        for start in range(text_length):
            cursor = self.root
            for position in range(start, text_length):
                symbol = string[position]
                if symbol not in cursor.children:
                    break
                cursor = cursor.children[symbol]
                if cursor.is_end_of_key:
                    hits.append(cursor.value)
        return hits
75
+
76
+
77
+ # Matching helpers
78
def match_codes(trie, name):
    """Return the code info of every node whose key occurs in ``name``.

    Each matched node contributes its ``code_info`` entries joined with
    ``Const.NEW_LINE``; the per-node strings are joined with the same
    separator.
    """
    per_node_code = [
        Const.NEW_LINE.join(hit.code_info)
        for hit in trie.search_in_string(name)
    ]
    return Const.NEW_LINE.join(per_node_code)
82
+
83
+
84
def match_names(trie, name):
    """Return the scopes of every node whose key occurs in ``name``.

    The scopes are joined with ``Const.NEW_LINE``.
    """
    scopes = []
    for hit in trie.search_in_string(name):
        scopes.append(hit.scope)
    return Const.NEW_LINE.join(scopes)
88
+
89
+
90
def map_op_names_to_codes_and_scopes(df, match_dict):
    """Add code-stack and scope-name columns to ``df`` and return it.

    A trie is built from ``match_dict`` and every value of the
    ``Const.OP_NAME`` column is scanned for the inserted keys. The results
    are written to the ``Const.CODE_STACK`` and ``Const.SCOPE_NAME`` columns
    of the same DataFrame (modified in place).
    """
    lookup = Trie()
    for scope_key in match_dict:
        lookup.insert(scope_key, match_dict[scope_key])

    def _codes_for(op_name):
        return match_codes(lookup, op_name)

    def _scopes_for(op_name):
        return match_names(lookup, op_name)

    df[Const.CODE_STACK] = df[Const.OP_NAME].apply(_codes_for)
    df[Const.SCOPE_NAME] = df[Const.OP_NAME].apply(_scopes_for)
    return df
99
+
100
+
101
def find_npy_files(npy_path):
    """
    Collect the .npy files located at or below ``npy_path``.

    Parameters:
        npy_path (str): Path to search; either a single file or a directory.

    Returns:
        List[Path]: Resolved paths of the .npy files found. Empty when the
        path is neither a matching file nor a directory.
    """
    target = Path(npy_path)

    # A single file with the numpy suffix is returned on its own.
    if target.suffix == Const.NUMPY_SUFFIX and target.is_file():
        check_file_or_directory_path(target)
        return [target.resolve()]

    if not target.is_dir():
        logger.info(f"The specified path is neither an .npy file nor a directory: {npy_path}")
        return []

    # Directories are searched recursively; each hit is validated first.
    collected = []
    for candidate in target.rglob(Const.NUMPY_PATTERN):
        check_file_or_directory_path(candidate)
        collected.append(candidate.resolve())
    return collected
129
+
130
+
131
def write_to_csv(param: Dict, output_dir: str):
    """
    Write the binding result to a timestamped CSV file.

    Parameters:
        param (Dict): Data to write, shaped as
            {file name: (code stack, scope name)}.
        output_dir (str): Directory that receives the CSV file; it is
            created through ``create_directory`` before writing.
    """
    create_directory(output_dir)

    # The file name carries the local time so repeated runs do not collide
    # (at one-second resolution).
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    csv_path = Path(output_dir) / f"code_mapping_{stamp}.csv"
    check_path_before_create(csv_path)

    rows = []
    for file_name, (code_stack, scope_name) in param.items():
        rows.append((file_name, code_stack, scope_name))

    column_names = [Const.FILE_PATH, Const.CODE_STACK, Const.SCOPE_NAME]
    write_df_to_csv(pd.DataFrame(rows, columns=column_names), csv_path)
148
+
149
+
150
def find_statistic_files(path):
    """Return the ``statistic.csv`` files located at or below ``path``.

    A directory is searched recursively. Any other path is returned as a
    one-element list when its base name is ``statistic.csv`` (its existence
    is not checked) and yields an empty list otherwise.
    """
    if os.path.isdir(path):
        recursive_pattern = os.path.join(path, '**', "statistic.csv")
        return list(glob.glob(recursive_pattern, recursive=True))

    is_statistic_file = os.path.basename(path) == 'statistic.csv'
    return [path] if is_statistic_file else []
160
+
161
+
162
def check_and_fix_header(file_path: str):
    """
    Make sure the header line of a CSV file ends with a comma.

    Parameters:
        file_path (str): Path of the CSV file.

    Returns:
        bool: True if the header was rewritten, False if the file is empty
        or the header already ends with a comma.
    """
    with FileOpen(file_path, "r") as reader:
        content = reader.readlines()

    if not content:
        logger.warning(f"The file {file_path} is empty.")
        return False

    # Drop the line terminator before inspecting the last character.
    header = content[0].rstrip(Const.NEW_LINE).rstrip('\r')

    if header.endswith(','):
        logger.info(f"The header already ends with a comma. No modification needed for the file: {file_path}.")
        return False

    logger.info(f"The header does not end with a comma. Adding a comma to the file: {file_path}.")
    # Re-attach the terminator together with the missing comma.
    content[0] = header + Const.CSV_NEWLINE_SEPARATOR

    # Write the whole file back with the repaired header.
    with FileOpen(file_path, "w") as writer:
        writer.writelines(content)
    logger.info(f"Added a trailing comma to the file: {file_path}.")
    return True
196
+
197
+
198
def bind_for_statistic(statistic_files: List[str], match_dict: Dict):
    """
    Bind code information to every statistic file, rewriting each in place.

    Parameters:
        statistic_files (List[str]): Paths of the statistic CSV files.
        match_dict (Dict): Scope-name keys mapped to the values used for
            matching operator names.
    """
    for csv_path in statistic_files:
        # Repair the header first so the file can be parsed consistently.
        if check_and_fix_header(csv_path):
            logger.info(f"The header of the file {csv_path} has been fixed.")

        frame = read_csv(csv_path, as_pd=True)

        # Add the code-stack and scope-name columns.
        frame = map_op_names_to_codes_and_scopes(frame, match_dict)

        # Overwrite the original statistic file with the enriched table.
        write_df_to_csv(frame, csv_path)
219
+
220
+
221
def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict[str, tuple]:
    """
    Bind code information to the dump data found under ``input_dir``.

    When .npy files are found, each one is matched against the graph nodes
    and the result is returned. When none are found, any ``statistic.csv``
    files are enriched in place instead and an empty dict is returned.

    Parameters:
        input_dir (str): File or directory containing the dump data.
        nodes (Dict[str, GraphNode]): Parsed graph nodes; subgraph nodes are
            ignored.

    Returns:
        Dict[str, tuple]: {real path of the .npy file: (code stack, scope
        names)}. Files that cannot be resolved to an operator name are
        skipped with a warning.
    """
    match_dict = {}
    for node in nodes.values():
        # Subgraph nodes carry no operator scope of their own.
        if node.is_subgraph:
            continue
        # Normalise the scope so it can be found inside dump file names.
        scope_name = node.scope.replace(Const.SCOPE_SEPARATOR, Const.REPLACEMENT_CHARACTER)
        match_dict[scope_name] = node
    npy_files = find_npy_files(input_dir)

    bind_result = {}
    if not npy_files:
        statistic_files = find_statistic_files(input_dir)
        if statistic_files:
            bind_for_statistic(statistic_files, match_dict)
        return bind_result

    # The trie depends only on match_dict, so build it once for all files.
    trie = Trie()
    for key, value in match_dict.items():
        trie.insert(key, value)

    # mapping.csv is shared by all files of a directory; read it only once.
    mapping_tables = {}

    for npy_file in npy_files:
        directory, file_name = os.path.split(npy_file)
        name_without_ext = os.path.splitext(file_name)[0]
        if name_without_ext.isdigit():
            # Purely numeric file names are resolved through mapping.csv,
            # whose first column is the file name and second the real name.
            if directory not in mapping_tables:
                csv_file_path = os.path.join(directory, 'mapping.csv')
                check_file_or_directory_path(csv_file_path)
                mapping_tables[directory] = read_csv(csv_file_path, header=None)
            df = mapping_tables[directory]

            matching_row = df[df[0] == file_name]
            if matching_row.empty:
                # Previously this fell through to os.path.splitext(None).
                logger.warning(f"No entry for {file_name} found in mapping.csv under {directory}, skipped.")
                continue
            corresponding_name = matching_row[1].values[0]
            name_without_ext = os.path.splitext(corresponding_name)[0]

        name_parts = name_without_ext.split(".")
        if len(name_parts) < 2:
            # The scope is the second dot-separated field of the name.
            logger.warning(f"Cannot extract a scope from the file name of {npy_file}, skipped.")
            continue
        node_scope = name_parts[1]

        npy_path = os.path.realpath(npy_file)
        bind_code = match_codes(trie, node_scope)
        bind_name = match_names(trie, node_scope)
        bind_result[npy_path] = (bind_code, bind_name)
    return bind_result
@@ -0,0 +1,40 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory, FileChecker
19
+
20
+
21
def add_ir_parser_arguments(parser):
    """Register the code-mapping options (--ir, --dump_data, --output) on ``parser``."""
    option_specs = (
        ('--ir', {'required': True, 'help': "Path to the graph file"}),
        ('--dump_data', {'required': True, 'default': None, 'help': "Path to data dir"}),
        ('--output', {'required': False, 'default': "./", 'help': "Path to output dir"}),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, type=str, **settings)
25
+
26
+
27
def check_args(args):
    """Normalise the parsed paths to absolute form and validate them.

    ``args.ir``, ``args.dump_data`` and ``args.output`` are overwritten with
    their absolute paths. The first two are validated with
    ``check_file_or_directory_path``; the output directory is created
    through ``create_directory``.
    """
    ir_path = os.path.abspath(args.ir)
    args.ir = ir_path
    check_file_or_directory_path(ir_path)

    dump_path = os.path.abspath(args.dump_data)
    args.dump_data = dump_path
    # The dump data may be a directory or a single file; validate it as
    # whichever kind it is.
    check_file_or_directory_path(dump_path, isdir=os.path.isdir(dump_path))

    output_path = os.path.abspath(args.output)
    args.output = output_path
    create_directory(output_path)
40
+
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Dict, Union
17
+
18
+
19
class GraphNode:
    """A node of the parsed graph: an operator, a constant, a parameter or a subgraph."""

    def __init__(self, name: str, pos: int = -1, unique_name: str = "", operator_name: str = "",
                 return_variable: str = "", return_value: str = "",
                 var_inputs: List[str] = None, has_constant_input: bool = False,
                 unique_id: str = "", scope: str = "", code_info: List[str] = None,
                 is_subgraph: bool = False, attrs: Union[Dict[str, str], List[str]] = None):
        self.name = name
        self.unique_name = unique_name
        self.pos = pos
        self.operator_name = operator_name
        self.return_variable = return_variable
        self.return_value = return_value
        # None (or any empty value) is replaced by a fresh list per instance.
        self.var_inputs = var_inputs if var_inputs else []
        self.has_constant_input = has_constant_input
        self.unique_id = unique_id
        self.scope = scope
        self.code_info = code_info if code_info else []
        # Subgraphs default to a list of attributes, other nodes to a dict.
        self.attrs = attrs if attrs else ({} if not is_subgraph else [])
        self.nodes = {}  # Internal nodes if this is a subgraph
        self.predecessors = []  # Predecessor nodes
        self.successors = []  # Successor nodes
        self.is_subgraph = is_subgraph

    def trace_back_ancestors(self, ancestors: List[str], visited: Dict[str, bool], parser) -> None:
        """Append the unique names of this node and all of its ancestors to ``ancestors``.

        Nodes are visited depth-first along ``predecessors``, in list order,
        and each unique name is recorded at most once.

        Parameters:
            ancestors: Output list, extended in place in visiting order.
            visited: Maps unique names to a visited flag and is updated in
                place. Missing names count as not visited.
            parser: Unused; kept so existing callers keep working.
        """
        # An explicit stack replaces recursion so long predecessor chains
        # cannot exceed the interpreter's recursion limit.
        pending = [self]
        while pending:
            node = pending.pop()
            if visited.get(node.unique_name):
                continue
            visited[node.unique_name] = True
            ancestors.append(node.unique_name)
            # Push in reverse so the first predecessor is handled first,
            # matching the order of a recursive traversal.
            pending.extend(reversed(node.predecessors))
49
+
@@ -0,0 +1,226 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import logging
18
+ from typing import Tuple, List, Dict
19
+ from msprobe.mindspore.code_mapping.graph import GraphNode
20
+
21
+
22
class Parser:
    """Parses the text of a graph IR file into ``GraphNode`` objects.

    Attributes:
        nodes: All parsed nodes keyed by function/subgraph name, operator
            unique name ("%N&variable"), constant text or parameter name.
        local_dict: Maps "%N" series numbers (and constants) to node keys.
        number_dict: Maps a unique id to the non-gradient node that owns it.
    """

    def __init__(self):
        self.nodes = {}
        self.local_dict = {}
        self.number_dict = {}

    @staticmethod
    def parse_subgraph_attributes(text: str, subgraph_node: "GraphNode", start_pos: int, end_pos: int) -> None:
        """Append the lines following "subgraph attr:" in ``text[start_pos:end_pos]`` to the node's attrs list."""
        subgraph_attr_pattern = re.compile(r'subgraph attr:\s*(.*)', re.DOTALL)
        match = subgraph_attr_pattern.search(text, start_pos, end_pos)
        if match:
            attrs = match.group(1).strip().split('\n')
            if isinstance(subgraph_node.attrs, list):
                subgraph_node.attrs.extend(attrs)

    @staticmethod
    def parse_graph_attributes(text: str, graph_node: "GraphNode") -> None:
        """Store the "key: value" lines following "# Attrs:" in the node's attrs dict.

        The search starts at ``graph_node.pos``. Reading stops at the first
        empty line or the first line without a colon. Only the first colon
        separates key and value, so values may contain colons themselves.
        """
        attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
        match = attr_pattern.search(text, graph_node.pos)
        if not match:
            return
        for attr in match.group(1).strip().split('\n'):
            if not attr or ':' not in attr:
                break
            key, value = attr.split(':', 1)
            if isinstance(graph_node.attrs, dict):
                graph_node.attrs[key.strip()] = value.strip()

    @staticmethod
    def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
        """Collect the consecutive "# ..." comment lines after ``start_pos``.

        Lines are taken from ``text[start_pos + 1:end_pos]``; a falsy
        ``end_pos`` means "up to, but excluding, the last character".
        Collection stops at the first line without a "# " comment. Leading
        and trailing '#', ' ' and then '/' characters are stripped.
        """
        code_info = []
        code_info_pattern = re.compile(r'# .*', re.MULTILINE)
        final_pos = end_pos if end_pos else len(text) - 1
        for line in text[start_pos + 1:final_pos].split('\n'):
            match = code_info_pattern.search(line)
            if not match:
                break
            code_info.append(match.group(0).strip('# ').strip('/'))
        return code_info

    @staticmethod
    def extract_bracket_content(text: str, start_pos: int) -> Tuple[str, int]:
        """Return the text from ``start_pos`` up to the ')' that balances the parentheses.

        Returns:
            (content, index): the substring including the closing ')' and
            the index of that ')' in ``text``.

        Raises:
            ValueError: if a ')' has no matching '(' or the end of the text
                is reached before the parentheses balance.
        """
        depth = 0
        for i in range(start_pos, len(text)):
            char = text[i]
            if char == '(':
                depth += 1
            elif char == ')':
                if depth == 0:
                    raise ValueError("Mismatched parentheses")
                depth -= 1
                if depth == 0:
                    return text[start_pos:i + 1], i
        raise ValueError("Mismatched parentheses")

    @staticmethod
    def find_matching_brace(text: str, start_pos: int) -> int:
        """Return the index of the '}' that closes the first '{' at or after ``start_pos``.

        Raises:
            ValueError: if a '}' appears before any '{' or no balancing '}'
                exists.
        """
        depth = 0
        for i in range(start_pos, len(text)):
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                if depth == 0:
                    raise ValueError("Matching closing brace not found")
                depth -= 1
                if depth == 0:
                    return i
        raise ValueError("Matching closing brace not found")

    @staticmethod
    def extract_constants(inputs_str: str) -> List[str]:
        """Return every "Name(...)" occurrence in ``inputs_str`` (shortest match up to the first ')')."""
        constant_pattern = re.compile(r'\b(\w+\(.*?\))')
        return constant_pattern.findall(inputs_str)

    def parse_func_graph(self, text: str) -> None:
        """Register one node per "# IR entry: @name" line found in ``text``."""
        func_graph_pattern = re.compile(r'# IR entry: @(\S+)')
        for match in func_graph_pattern.finditer(text):
            func_name = match.group(1)
            self.nodes[func_name] = GraphNode(name=func_name, pos=match.start(), is_subgraph=False)

    def parse_nodes(self, text: str, subgraph_info: "GraphNode") -> None:
        """Parse every "%N(var) = Op(...)" statement in ``text`` into a node.

        For each statement the inputs, constants, scope, unique id and code
        info are extracted and the node is linked to its predecessors. New
        nodes are stored in ``self.nodes`` and, when ``subgraph_info`` is
        given, also in ``subgraph_info.nodes``.
        """
        # These patterns do not depend on the statement, so compile them once.
        node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(')
        scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE)
        id_pattern = re.compile(r'.*cnode_primal_attrs:'
                                r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE)

        matches = list(node_pattern.finditer(text))
        for i, match in enumerate(matches):
            series_number = match.group(1)
            variable_name = match.group(2)
            operator_name = match.group(3)
            unique_name = "&".join([series_number, variable_name])
            self.local_dict[series_number] = unique_name

            # match.end() - 1 is the '(' that opens the argument list.
            args_str, end_pos = self.__class__.extract_bracket_content(text, match.end() - 1)
            inputs = re.findall(r'%\w+', args_str)
            inputs += re.findall(r'@\w+', args_str)
            constants = self.__class__.extract_constants(args_str)

            scope_match = scope_pattern.search(text, end_pos)
            scope = scope_match.group(1) if scope_match else ""

            # The unique id is looked up between the statement and its scope
            # comment. Without a scope comment, stop at the next statement
            # (or the end of the text) instead of dereferencing None.
            has_next = i < len(matches) - 1
            if scope_match:
                id_search_end = scope_match.start()
            else:
                id_search_end = matches[i + 1].start() if has_next else len(text)
            unique_id_match = id_pattern.search(text, end_pos, id_search_end)
            unique_id = unique_id_match.group(1) if unique_id_match else None

            if scope:
                next_match = matches[i + 1].start() - 1 if has_next else None
                code_info = self.__class__.parse_code_info(text, scope_match.end(), next_match)
            else:
                code_info = None

            node_info = GraphNode(name=variable_name, unique_name=unique_name, operator_name=operator_name,
                                  var_inputs=inputs + constants, unique_id=unique_id, scope=scope,
                                  code_info=code_info)

            # Only non-gradient nodes are remembered by unique id; see
            # create_backward_map for how they are used.
            if unique_id and scope and not scope.startswith("Gradients"):
                self.number_dict[unique_id] = node_info

            if subgraph_info:
                subgraph_info.nodes[variable_name] = node_info

            # The first node registered under a unique name wins.
            if not self.nodes.get(unique_name, None):
                self.nodes[unique_name] = node_info

            for const in constants:
                if const not in self.nodes:
                    const_node = GraphNode(name=const, operator_name="Constant", var_inputs=[],
                                           has_constant_input=True)
                    self.nodes[const] = const_node
                    if subgraph_info:
                        subgraph_info.nodes[const] = const_node
                self.local_dict[const] = const

            for input_var in node_info.var_inputs:
                if input_var in self.local_dict or input_var in self.nodes:
                    input_name = self.local_dict.get(input_var, input_var)
                    input_node = self.nodes.get(input_name, None)
                    if input_node:
                        node_info.predecessors.append(input_node)
                        input_node.successors.append(node_info)
                else:
                    # Inputs not seen so far are treated as parameters.
                    param_node = GraphNode(name=input_var, operator_name="Param", var_inputs=[],
                                           has_constant_input=False)
                    if not self.nodes.get(input_var, None):
                        self.nodes[input_var] = param_node
                    node_info.predecessors.append(param_node)
                    param_node.successors.append(node_info)

    def extract_callees(self, text: str) -> None:
        """Add the targets of "Partial(@callee(" calls to each node's ``var_inputs``.

        For every node the text from ``node.pos`` up to the next '}' is
        scanned; callee names not yet in ``var_inputs`` are appended.
        """
        callee_pattern = re.compile(r'Partial\(@(\S+)\(')
        for node_info in self.nodes.values():
            func_start_pos = node_info.pos
            func_end_pos = text.find('}', func_start_pos)
            func_text = text[func_start_pos:func_end_pos]
            for callee_match in callee_pattern.finditer(func_text):
                callee_name = callee_match.group(1)
                if callee_name not in node_info.var_inputs:
                    node_info.var_inputs.append(callee_name)

    def parse_subgraphs(self, text: str) -> None:
        """Register every "subgraph @name ... {" block and parse the nodes inside it.

        The text between the previous subgraph's end and this subgraph's
        start is searched for "subgraph attr:" attributes.
        """
        subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{')
        end_pos = 0
        for match in list(subgraph_pattern.finditer(text)):
            # Attributes are searched from two characters past the previous end.
            last_pos = end_pos + 2
            subgraph_name = match.group(1).split('(')[0]
            start_pos = match.start()
            end_pos = self.__class__.find_matching_brace(text, start_pos)
            subgraph_text = text[start_pos:end_pos + 1]
            subgraph_info = GraphNode(name=subgraph_name, pos=start_pos, is_subgraph=True)
            self.nodes[subgraph_name] = subgraph_info
            self.__class__.parse_subgraph_attributes(text, subgraph_info, last_pos, start_pos)
            self.parse_nodes(subgraph_text, subgraph_info)
            subgraph_info.end = end_pos
            logging.info('Parsed subgraph: %s', subgraph_name)

    def count_nodes(self) -> Tuple[int, int]:
        """Return (number of nodes, number of nodes whose name starts with 'CNode')."""
        total_nodes = len(self.nodes)
        total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
        return total_nodes, total_cnodes

    def create_backward_map(self):
        """Copy code info onto "Gradients" nodes from the node sharing their unique id."""
        for node in self.nodes.values():
            if node.scope and node.scope.startswith("Gradients"):
                related_forward_node = self.number_dict.get(node.unique_id, None)
                if related_forward_node:
                    node.code_info = related_forward_node.code_info

    def parse(self, text: str) -> None:
        """Run all parsing passes over the IR text, in dependency order."""
        self.parse_func_graph(text)
        self.parse_subgraphs(text)
        self.parse_nodes(text, None)
        self.extract_callees(text)
        self.create_backward_map()

    def get_nodes(self) -> Dict[str, "GraphNode"]:
        """Return the dictionary of all parsed nodes."""
        return self.nodes
@@ -0,0 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.code_mapping.processor import process
17
+ from msprobe.mindspore.code_mapping.cmd_parser import check_args
18
+
19
+
20
def code_mapping_main(args):
    """Entry point for code mapping: check *args*, then process them.

    Args:
        args: Argument object passed unchanged to ``check_args`` and then
            to ``process``.
    """
    # Checking runs first so ``process`` is only reached if it does not raise.
    check_args(args)
    process(args)
23
+
24
+
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.code_mapping.graph_parser import Parser
17
+ from msprobe.mindspore.code_mapping.bind import bind_code_info_for_data, write_to_csv
18
+ from msprobe.core.common.file_utils import FileOpen
19
+
20
+
21
def process(args):
    """Parse the IR file, bind code info to the dump data and write a CSV.

    Reads the file at ``args.ir``, parses it with ``Parser``, and passes the
    resulting nodes together with ``args.dump_data`` to
    ``bind_code_info_for_data``. When that returns a truthy result it is
    written via ``write_to_csv`` using ``args.output``; otherwise nothing is
    written.

    Args:
        args: Object exposing ``ir``, ``dump_data`` and ``output`` attributes.
    """
    ir_path = args.ir
    with FileOpen(ir_path, 'r') as ir_file:
        ir_text = ir_file.read()

    graph_parser = Parser()
    graph_parser.parse(ir_text)
    parsed_nodes = graph_parser.get_nodes()

    bound = bind_code_info_for_data(args.dump_data, parsed_nodes)
    if not bound:
        # Nothing was bound, so there is no CSV to emit.
        return
    write_to_csv(bound, args.output)
34
+
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -48,6 +48,8 @@ class Const:
48
48
  MINT_DATA_PREFIX = "Mint."
49
49
  MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."
50
50
  DISTRIBUTED_DATA_PREFIX = "Distributed."
51
+ TORCH_DATA_PREFIX = "Torch."
52
+ TORCH_NPU_DATA_PREFIX = "NPU."
51
53
 
52
54
  SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
53
55
  SUPPORTED_TENSOR_LIST_KEY = "tensor"