mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import CompareConst, Const
19
+ from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
20
+ from msprobe.core.common.utils import (add_time_with_yaml,
21
+ detect_framework_by_dump_json,
22
+ get_stack_construct_by_dump_json_path)
23
+ from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
24
+ from msprobe.core.compare.utils import read_op
25
+
26
+
27
class LayerTrie:
    """Prefix tree of dump data items, keyed by the components of their scope.

    Every node stands for one scope component; the data items whose
    ``full_scope`` ends at a node are kept in that node's ``data_items`` list
    in insertion order.
    """

    def __init__(self, type_name, framework=None):
        # Layer type name of this node; insert() overwrites it on nodes that receive data.
        self.type_name = type_name
        # Data items whose full scope terminates at this node, in insertion order.
        self.data_items = []
        # Scope component name -> child LayerTrie.
        self.children = {}
        # Framework tag; used in the name of the YAML file written by save_to_yaml().
        self.framework = framework

    def __repr__(self):
        return f"Layer(type_name={self.type_name}, data_number={len(self.data_items)})"

    def get(self, name):
        """Return the child node called *name*, or None if there is none."""
        return self.children.get(name)

    def insert(self, data_item):
        """Store *data_item* at the node addressed by its ``full_scope``.

        The scope is split on ``Const.SEP`` and the components from
        ``Const.RIGHT_MOVE_INDEX`` onward are walked (presumably 1, i.e. the
        root component is skipped -- TODO confirm), creating missing nodes on
        the way. The terminal node takes the item's ``type_name``.
        """
        node = self
        for name in data_item.full_scope.split(Const.SEP)[Const.RIGHT_MOVE_INDEX:]:
            if name not in node.children:
                node.children[name] = LayerTrie(name, data_item.framework)
            node = node.children[name]
        node.data_items.append(data_item)
        node.type_name = data_item.type_name

    def query_data(self, scope, index, default_value=None):
        """Return the ``index``-th data item stored at *scope*.

        The first component of *scope* (the root) is skipped. Returns
        *default_value* when the scope path does not exist or *index* is out
        of range. Negative indices count as out of range: an index of -1 means
        "not found" and must not wrap around to the last stored item.
        """
        node = self
        for name in scope.split(Const.SEP)[1:]:
            if name not in node.children:
                return default_value
            node = node.children[name]
        if not 0 <= index < len(node.data_items):
            return default_value
        return node.data_items[index]

    def save_to_yaml(self, output_path):
        """Write the whole trie as YAML into the directory *output_path*.

        The file name is built by ``add_time_with_yaml`` from
        ``"<framework>_tree"``.
        """
        result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
        file_name = add_time_with_yaml(f"{self.framework}_tree")
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        save_yaml(file_path, result)

    def convert_to_dict(self, node):
        """Recursively convert *node* and its subtree into nested dicts.

        ``"data_item"`` lists the data names stored at *node*; every child is
        nested under the key ``"<component> @ <repr(child)>"``.
        """
        result = {"data_item": [item.data_name for item in node.data_items]}
        for child_key, child_node in node.children.items():
            result[f"{child_key} @ {child_node}"] = self.convert_to_dict(child_node)
        return result
78
+
79
+
80
def convert_scope(layer_trie, data_item, mapping=None):
    """Translate the scope of *data_item* into the naming used by the other side.

    Walks *layer_trie* along ``data_item.full_scope`` (split on ``Const.SEP``).
    At each node, the rules in ``mapping[node.type_name][child_name]`` -- the
    ``(origin, target, level)`` tuples produced by ``preprocess_layer_mapping``,
    longest first -- are tried in order. The first rule whose ``origin`` equals
    the next ``level`` scope components replaces them with ``target``.
    Components without a matching rule are copied unchanged.

    Args:
        layer_trie: root LayerTrie the data item was inserted into.
        data_item: item whose ``full_scope``/``data_name`` are translated.
        mapping: preprocessed layer-mapping rules; falsy means "no rules".
            It is only read, never modified.

    Returns:
        (new_scope, index): the translated scope rooted at ``Const.TOP_LAYER``,
        and the position of *data_item* among the data items of the node it
        ends at. If several items share its data name, the last position is
        returned; -1 if it is not found there.
    """
    if not mapping:
        mapping = {}
    new_scope = Const.TOP_LAYER
    scope_list = data_item.full_scope.split(Const.SEP)
    cur_node = layer_trie

    idx = 0
    while idx < len(scope_list) - 1:
        child_name = scope_list[idx + 1]
        candidates = mapping.get(cur_node.type_name, {}).get(child_name, [])
        # Add the identity rule as a fallback on a fresh list: appending it to
        # ``candidates`` would mutate the caller's mapping and make its rule
        # lists grow on every call.
        step = 1
        for origin, target, level in [*candidates, (child_name, child_name, 1)]:
            if Const.SEP.join(scope_list[idx + 1: idx + level + 1]) == origin:
                new_scope = new_scope + Const.SEP + target
                step = level
                break
        # A matched rule may consume several components at once.
        for _ in range(step):
            cur_node = cur_node.get(scope_list[idx + 1])
            idx += 1

    index = -1
    for pos, child in enumerate(cur_node.data_items):
        if data_item.data_name == child.data_name:
            index = pos
    return new_scope, index
109
+
110
+
111
def get_data_items_and_tree(dump_json_path, output_path):
    """Parse one dump.json file and index its data items in a LayerTrie.

    Args:
        dump_json_path: path of the dump.json to read.
        output_path: forwarded to ``get_dump_data_items``; when truthy, the
            built trie is also saved there as YAML.

    Returns:
        (dump_data_items, root): the parsed data items and the trie holding them.
    """
    framework = detect_framework_by_dump_json(dump_json_path)
    stack, construct = get_stack_construct_by_dump_json_path(dump_json_path)
    dump_content = load_json(dump_json_path)
    items = get_dump_data_items(dump_content, stack, construct, framework, output_path)

    trie = LayerTrie(Const.TOP_LAYER, framework)
    for item in items:
        trie.insert(item)

    if output_path:
        trie.save_to_yaml(output_path)
    return items, trie
122
+
123
+
124
def convert_data_item(npu_tree, bench_tree, npu_data_item, mapping):
    """Look up the bench data item that corresponds to *npu_data_item*.

    The NPU item's scope is translated with ``convert_scope`` and the result is
    queried in *bench_tree*; returns ``query_data``'s default (None) when the
    bench tree holds nothing there.
    """
    translated_scope, position = convert_scope(npu_tree, npu_data_item, mapping)
    return bench_tree.query_data(translated_scope, position)
128
+
129
+
130
def update_keys_in_place(d):
    """Rename the legacy top-layer key of *d*, modifying *d* in place.

    Keeps old and new mapping files compatible: old versions named the top
    layer 'Cell' (``Const.CELL``), new versions use 'TopLayer'
    (``Const.TOP_LAYER``). A ``Const.CELL`` entry is always removed; its value
    is moved to ``Const.TOP_LAYER`` unless it is None.
    """
    if Const.CELL not in d:
        return
    legacy_value = d.pop(Const.CELL)
    if legacy_value is not None:
        d[Const.TOP_LAYER] = legacy_value
139
+
140
+
141
def preprocess_layer_mapping(mapping):
    """Group layer-mapping rules by the first component of their source name.

    before:
    {'A': {'a.b.c': 'new_c',
           'a.demo': 'new_demo',
           'z': 'new_z',
           'd.e': 'e'}}
    after:
    {'A': {'a': [('a.b.c', 'new_c', 3), ('a.demo', 'new_demo', 2)],
           'z': [('z', 'new_z', 1)],
           'd': [('d.e', 'e', 2)]}}

    Each rule becomes ``(source_name, target_name, number_of_components)``.
    Within a prefix, rules are ordered longest first, so the most specific
    rule is tried first.

    The legacy top-layer key is renamed with ``update_keys_in_place`` on a
    shallow copy, so the caller's dict is left untouched. A falsy mapping
    (e.g. an empty YAML file) or a type entry without rules (an empty YAML
    section) contributes no rules.
    """
    if not mapping:
        return {}
    mapping = dict(mapping)
    update_keys_in_place(mapping)
    final_mapping = {}

    for type_name, name_map in mapping.items():
        prefix_rules = {}
        for key, value in (name_map or {}).items():
            key_list = key.split('.')
            # Group by the first component (the prefix).
            prefix_rules.setdefault(key_list[0], []).append((key, value, len(key_list)))

        # Longest rules first; the sort is stable, so ties keep the file order.
        for rules in prefix_rules.values():
            rules.sort(key=lambda rule: rule[2], reverse=True)
        final_mapping[type_name] = prefix_rules

    return final_mapping
172
+
173
+
174
def convert_data_items(npu_tree, bench_tree, npu_data_items, mapping):
    """Map each NPU data name to the matching bench data name.

    *mapping* is the raw layer-mapping dict; it is preprocessed once, then
    every item in *npu_data_items* is translated. NPU items with no bench
    counterpart map to ``CompareConst.N_A``.
    """
    rules = preprocess_layer_mapping(mapping)
    result = {}
    for npu_item in npu_data_items:
        bench_item = convert_data_item(npu_tree, bench_tree, npu_item, rules)
        if bench_item:
            result[npu_item.data_name] = bench_item.data_name
        else:
            result[npu_item.data_name] = CompareConst.N_A
    return result
183
+
184
+
185
def generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path=None, output_path=None):
    """Build the NPU -> bench data-name mapping for two dump.json files.

    Args:
        npu_json_path: dump.json of the NPU side.
        bench_json_path: dump.json of the bench side.
        layer_mapping_path: optional layer-mapping YAML; used only if it is a str.
        output_path: when truthy, both tries and the resulting mapping are
            saved as YAML into this directory.

    Returns:
        dict: NPU data name -> bench data name, or ``CompareConst.N_A``.
    """
    npu_data_items, npu_root = get_data_items_and_tree(npu_json_path, output_path)
    bench_root = get_data_items_and_tree(bench_json_path, output_path)[1]

    mapping = load_yaml(layer_mapping_path) if isinstance(layer_mapping_path, str) else {}
    api_mapping = convert_data_items(npu_root, bench_root, npu_data_items, mapping)

    if output_path:
        yaml_name = add_time_with_yaml("api_mapping")
        save_yaml(os.path.join(os.path.realpath(output_path), yaml_name), api_mapping)
    return api_mapping
198
+
199
+
200
def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path=None):
    """Expand an API-level mapping into a mapping between individual data entries.

    For every ``npu_op_name -> bench_op_name`` pair in *api_mapping* (empty NPU
    names are skipped), the full op names that ``read_op`` yields on each side
    are paired up by their suffix: the text left after dropping
    ``len(op_name)`` leading characters. NPU entries without a bench entry of
    the same suffix map to ``CompareConst.N_A``. When *output_path* is truthy,
    the result is also saved as YAML into that directory.
    """
    def full_op_names_of(dump_data, op_name):
        # Full op names parsed from the dump entry of *op_name* (an empty entry if absent).
        parsed_ops = read_op(dump_data.get(op_name, {}), op_name)
        return [parsed.get('full_op_name') for parsed in parsed_ops]

    def match_by_suffix(npu_op_name, npu_full_names, bench_op_name, bench_full_names):
        # Later bench names with an identical suffix win, as with sequential assignment.
        bench_by_suffix = {name[len(bench_op_name):]: name for name in bench_full_names}
        return {
            name: bench_by_suffix.get(name[len(npu_op_name):], CompareConst.N_A)
            for name in npu_full_names
        }

    npu_data = load_json(npu_json_path).get("data", {})
    bench_data = load_json(bench_json_path).get("data", {})

    data_mapping = {}
    for npu_op_name, bench_op_name in api_mapping.items():
        if npu_op_name:
            npu_full_names = full_op_names_of(npu_data, npu_op_name)
            bench_full_names = full_op_names_of(bench_data, bench_op_name)
            data_mapping.update(
                match_by_suffix(npu_op_name, npu_full_names, bench_op_name, bench_full_names))

    if output_path:
        yaml_name = add_time_with_yaml("data_mapping")
        save_yaml(os.path.join(os.path.realpath(output_path), yaml_name), data_mapping)
    return data_mapping
233
+
234
+
235
def generate_data_mapping_by_layer_mapping(input_param, layer_mapping_path=None, output_path=None):
    """Produce the data-entry mapping for the dump files named in *input_param*.

    *input_param* supplies ``"npu_json_path"`` and ``"bench_json_path"``.
    *output_path* is forwarded only to the data-mapping step, so the tries
    and the intermediate API mapping are not saved as YAML by this call.
    """
    npu_path = input_param.get("npu_json_path")
    bench_path = input_param.get("bench_json_path")
    return generate_data_mapping(
        npu_path,
        bench_path,
        generate_api_mapping_by_layer_mapping(npu_path, bench_path, layer_mapping_path),
        output_path,
    )
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import math
17
+
18
+ from msprobe.core.common.const import Const
19
+
20
+
21
def postprocess_pass(data_items, name2item):
    """Run the post-processing passes over *data_items*, which are updated in place.

    Order: first ``backward_pass`` (backward items borrow forward scopes), then
    ``renumber_index_pass`` on the ``layers`` of ``ParallelTransformer`` layers.
    """
    backward_pass(data_items=data_items, name2item=name2item)
    renumber_index_pass(data_items, type_name="ParallelTransformer", suffix="layers")
24
+
25
+
26
def backward_pass(data_items, name2item):
    """Give backward data items the scope information of their forward counterparts.

    Backward data has no stack information, so forward data's is reused. An
    item is handled when ``Const.BACKWARD`` is one of the components of its
    data name from ``Const.SCOPE_DIRECTION_INDEX`` on. In those components,
    ``Const.BACKWARD`` is replaced by ``Const.FORWARD``, and the resulting name
    is looked up in *name2item*. If a (truthy) forward item is found, its
    ``stack_scope``, ``full_scope`` and ``layer_scope`` are copied.
    """
    start = Const.SCOPE_DIRECTION_INDEX
    for data_item in data_items:
        name_parts = data_item.data_name.split(Const.SEP)
        if not name_parts or Const.BACKWARD not in name_parts[start:]:
            continue

        head, tail = name_parts[:start], name_parts[start:]
        renamed_tail = [part.replace(Const.BACKWARD, Const.FORWARD) for part in tail]
        forward_item = name2item.get(Const.SEP.join(head + renamed_tail))
        if not forward_item:
            continue

        data_item.stack_scope = forward_item.stack_scope
        data_item.full_scope = forward_item.full_scope
        data_item.layer_scope = forward_item.layer_scope
43
+
44
+
45
def extract_next_item_last_number(data, prefix, default_result=None):
    """Return the last integer inside the scope component that follows *prefix*.

    When *data* starts with ``prefix + "."``, the next component (up to the
    following "." or the end, without whitespace) is scanned for digit runs,
    and the last run is returned as an int. Otherwise, or if the component
    has no digits, *default_result* is returned.

    >>> extract_next_item_last_number("Top.layers.3.mlp", "Top.layers")
    3
    """
    found = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", data)
    if found is None:
        return default_result
    digit_runs = re.findall(r"\d+", found.group(1))
    return int(digit_runs[-1]) if digit_runs else default_result
54
+
55
+
56
def replace_next_item_index(full_scope, prefix, index):
    """Replace the scope component that follows *prefix* in *full_scope* by *index*.

    Example: ``("Top.layers.5.mlp", "Top.layers", 2)`` -> ``"Top.layers.2.mlp"``.

    *full_scope* is returned unchanged when it does not start with
    ``prefix + "."``, or when *index* is not finite. Both inf and nan mean no
    usable index was computed; nan results from ``inf - inf`` in
    ``renumber_index_pass``.
    """
    if not math.isfinite(index):
        return full_scope
    match = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", full_scope)
    if not match:
        return full_scope
    # Splice the new index in directly rather than via re.sub: a re.sub
    # replacement string built from *prefix* would have its backslashes
    # interpreted as escapes / group references.
    return f"{full_scope[:match.start(1)]}{index}{full_scope[match.end(1):]}"
67
+
68
+
69
def renumber_index_pass(data_items, type_name, suffix=None):
    """Renumber the member indices under layers of type *type_name* to start at 0.

    This resolves comparison mismatches caused by inconsistent numbering in
    parallel-split scenarios. For example, with pipeline-parallel splitting of
    ParallelTransformer layers, MindSpore numbers the members of ``layers``
    globally, while PyTorch numbers them locally. Indices of the given layers
    are renumbered so the sequence numbers line up in later stages.

    For every data item of *type_name*, the prefix is
    ``"<full_scope>.<suffix>"``, or the item's ``layer_scope`` when *suffix* is
    falsy. The smallest index found directly under a prefix is subtracted from
    the index of every item under it. ``full_scope`` of the data items is
    modified in place. Items whose next component carries no number are left
    unchanged, and so are all items under a prefix that has no numbered
    members.
    """
    # Prefix of each type_name layer -> smallest member index found under it.
    prefix_dict = {}
    for data_item in data_items:
        if data_item.type_name == type_name:
            prefix = f"{data_item.full_scope}.{suffix}" if suffix else data_item.layer_scope
            prefix_dict[prefix] = math.inf

    # Compute the minimum index for each prefix (before any renumbering).
    for prefix in prefix_dict:
        prefix_dict[prefix] = min(
            (extract_next_item_last_number(item.full_scope, prefix, math.inf) for item in data_items),
            default=math.inf,
        )

    # Renumber.
    for prefix, min_index in prefix_dict.items():
        if math.isinf(min_index):
            # No numbered members: inf - inf would give nan and corrupt scopes.
            continue
        for data_item in data_items:
            abs_index = extract_next_item_last_number(data_item.full_scope, prefix, math.inf)
            if math.isinf(abs_index):
                continue  # not under this prefix, or no number to shift
            data_item.full_scope = replace_next_item_index(
                data_item.full_scope, prefix, abs_index - min_index)
@@ -14,18 +14,32 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import abc
17
+
17
18
  import numpy as np
18
19
  from msprobe.core.common.utils import format_value
19
20
  from msprobe.core.common.const import Const, CompareConst
20
21
  from msprobe.core.common.log import logger
21
22
 
23
+ from msprobe.core.common.utils import CompareException
24
+
22
25
 
23
26
  def handle_inf_nan(n_value, b_value):
27
+ def convert_to_float(value):
28
+ try:
29
+ if isinstance(value, np.ndarray):
30
+ return value.astype(float)
31
+ else:
32
+ return float(value)
33
+ except ValueError as e:
34
+ logger.error('\n'.join(e.args))
35
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
36
+
37
+ n_value_convert, b_value_convert = convert_to_float(n_value), convert_to_float(b_value)
24
38
  """处理inf和nan的数据"""
25
- n_inf = np.isinf(n_value)
26
- b_inf = np.isinf(b_value)
27
- n_nan = np.isnan(n_value)
28
- b_nan = np.isnan(b_value)
39
+ n_inf = np.isinf(n_value_convert)
40
+ b_inf = np.isinf(b_value_convert)
41
+ n_nan = np.isnan(n_value_convert)
42
+ b_nan = np.isnan(b_value_convert)
29
43
  n_invalid = np.any(n_inf) or np.any(n_nan)
30
44
  b_invalid = np.any(b_inf) or np.any(b_nan)
31
45
  if n_invalid or b_invalid:
@@ -50,7 +64,11 @@ def get_error_type(n_value, b_value, error_flag):
50
64
  if not n_value.shape: # 判断数据是否为标量
51
65
  return n_value, b_value, False
52
66
 
53
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
67
+ try:
68
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
69
+ except CompareException:
70
+ logger.error('Numpy data is unreadable, please check!')
71
+ return CompareConst.UNREADABLE, CompareConst.UNREADABLE, True
54
72
  if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
55
73
  return CompareConst.NAN, CompareConst.NAN, True
56
74
  return n_value, b_value, False
@@ -73,7 +91,9 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
73
91
  """获取异常情况的错误信息"""
74
92
  if error_flag:
75
93
  if n_value == CompareConst.READ_NONE:
76
- if error_file:
94
+ if error_file == 'no_bench_data':
95
+ return 'Bench does not have data file.'
96
+ elif error_file is not None:
77
97
  return "Dump file: {} not found.".format(error_file)
78
98
  return CompareConst.NO_BENCH
79
99
  if n_value == CompareConst.NONE:
@@ -82,6 +102,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
82
102
  return "Shape of NPU and bench Tensor do not match. Skipped."
83
103
  if n_value == CompareConst.NAN:
84
104
  return "The position of inf or nan in NPU and bench Tensor do not match."
105
+ if n_value == CompareConst.UNREADABLE:
106
+ return "The npy data is unable to be read or compared, please check dump data files."
85
107
  else:
86
108
  if not n_value.shape:
87
109
  return "This is type of scalar data, can not compare."
@@ -109,7 +131,11 @@ def npy_data_check(n_value, b_value):
109
131
  error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
110
132
 
111
133
  if not error_message:
112
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有 nan/inf 数据
134
+ try:
135
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
136
+ except CompareException:
137
+ logger.error('Numpy data is unreadable, please check!')
138
+ return True, 'Numpy data is unreadable, please check!'
113
139
  # handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
114
140
  if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
115
141
  error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
@@ -160,14 +186,14 @@ class GetCosineSimilarity(TensorComparisonBasic):
160
186
 
161
187
  def apply(self, n_value, b_value, error_flag, relative_err=None):
162
188
  if error_flag:
163
- if n_value == CompareConst.READ_NONE:
164
- return CompareConst.NONE, ''
189
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
190
+ return CompareConst.UNSUPPORTED, ''
165
191
  if n_value == CompareConst.NONE:
166
192
  return CompareConst.UNSUPPORTED, ''
167
193
  if n_value == CompareConst.SHAPE_UNMATCH:
168
194
  return CompareConst.SHAPE_UNMATCH, ''
169
195
  if n_value == CompareConst.NAN:
170
- return "N/A", ''
196
+ return CompareConst.N_A, ''
171
197
 
172
198
  if not n_value.shape:
173
199
  return CompareConst.UNSUPPORTED, ''
@@ -198,17 +224,20 @@ class GetMaxAbsErr(TensorComparisonBasic):
198
224
  """计算最大绝对误差"""
199
225
  def apply(self, n_value, b_value, error_flag, relative_err=None):
200
226
  if error_flag:
201
- if n_value == CompareConst.READ_NONE:
202
- return CompareConst.NONE, ""
227
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
228
+ return CompareConst.UNSUPPORTED, ""
203
229
  if n_value == CompareConst.NONE:
204
230
  return 0, ""
205
231
  if n_value == CompareConst.SHAPE_UNMATCH:
206
232
  return CompareConst.SHAPE_UNMATCH, ""
207
233
  if n_value == CompareConst.NAN:
208
- return "N/A", ""
234
+ return CompareConst.N_A, ""
209
235
 
210
236
  temp_res = n_value - b_value
211
237
  max_value = np.max(np.abs(temp_res))
238
+ if np.isnan(max_value):
239
+ message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
240
+ return CompareConst.NAN, message
212
241
  return format_value(max_value), ""
213
242
 
214
243
 
@@ -228,20 +257,20 @@ class GetMaxRelativeErr(TensorComparisonBasic):
228
257
  """计算最大相对误差"""
229
258
  def apply(self, n_value, b_value, error_flag, relative_err=None):
230
259
  if error_flag:
231
- if n_value == CompareConst.READ_NONE:
232
- return CompareConst.NONE, ''
260
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
261
+ return CompareConst.UNSUPPORTED, ''
233
262
  if n_value == CompareConst.NONE:
234
263
  return 0, ''
235
264
  if n_value == CompareConst.SHAPE_UNMATCH:
236
265
  return CompareConst.SHAPE_UNMATCH, ''
237
266
  if n_value == CompareConst.NAN:
238
- return "N/A", ''
267
+ return CompareConst.N_A, ''
239
268
 
240
269
  if relative_err is None:
241
270
  relative_err = get_relative_err(n_value, b_value)
242
271
  max_relative_err = np.max(np.abs(relative_err))
243
272
  if np.isnan(max_relative_err):
244
- message = 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
273
+ message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
245
274
  return CompareConst.NAN, message
246
275
  return format_value(max_relative_err), ''
247
276
 
@@ -250,14 +279,14 @@ class GetThousandErrRatio(TensorComparisonBasic):
250
279
  """计算相对误差小于千分之一的比例"""
251
280
  def apply(self, n_value, b_value, error_flag, relative_err=None):
252
281
  if error_flag:
253
- if n_value == CompareConst.READ_NONE:
254
- return CompareConst.NONE, ""
282
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
283
+ return CompareConst.UNSUPPORTED, ""
255
284
  if n_value == CompareConst.NONE:
256
285
  return 0, ""
257
286
  if n_value == CompareConst.SHAPE_UNMATCH:
258
287
  return CompareConst.SHAPE_UNMATCH, ""
259
288
  if n_value == CompareConst.NAN:
260
- return "N/A", ""
289
+ return CompareConst.N_A, ""
261
290
 
262
291
  if not n_value.shape:
263
292
  return CompareConst.NAN, ""
@@ -272,14 +301,14 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
272
301
  """计算相对误差小于千分之五的比例"""
273
302
  def apply(self, n_value, b_value, error_flag, relative_err=None):
274
303
  if error_flag:
275
- if n_value == CompareConst.READ_NONE:
276
- return CompareConst.NONE, ""
304
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
305
+ return CompareConst.UNSUPPORTED, ""
277
306
  if n_value == CompareConst.NONE:
278
307
  return 0, ""
279
308
  if n_value == CompareConst.SHAPE_UNMATCH:
280
309
  return CompareConst.SHAPE_UNMATCH, ""
281
310
  if n_value == CompareConst.NAN:
282
- return "N/A", ""
311
+ return CompareConst.N_A, ""
283
312
 
284
313
  if not n_value.shape:
285
314
  return CompareConst.NAN, ""