mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ import re
2
+
3
+ from msprobe.core.common.const import Const
4
+ from msprobe.core.common.log import logger
5
+ from msprobe.core.common.utils import CompareException
6
+
7
+
8
class Trie:
    """Prefix tree over dot-separated dump names.

    Every component of a name except the last two becomes a path node; the node
    reached by the last of those components records the name's type name, call
    count and data-type tag.
    """

    def __init__(self, type_name=None, has_data=False):
        self.type_name = type_name
        # Call counts (last name component, kept as strings) of names ending at this node.
        self.call_count_list = []
        # Child path component -> Trie.
        self.children = {}
        self.has_data = has_data
        # `word_type` given to insert() for names ending here; None until then.
        self.node_type = None

    def __repr__(self):
        return (f"Node(type_name={self.type_name}, "
                f"has_data={self.has_data}, call number={len(self.call_count_list)})")

    def insert(self, word, word_type="func"):
        """Insert a dump name into the tree.

        The name is read as ``<path...>.<type_name>.<call_count>``. For example
        ``Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1``
        creates path nodes ``Cell`` ... ``attention`` ... ``out_proj``. The ``out_proj`` node then
        gets ``type_name = "RowParallelLinear"`` and call count ``"1"`` appended.

        Args:
            word: dot-separated name with at least two components.
            word_type: tag stored as ``node_type`` on the node that receives the data.

        Raises:
            CompareException: if ``word`` has fewer than two components.
        """
        parts = word.split(Const.SEP)
        if len(parts) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
        *prefix_name_list, type_name, call_count = parts

        node = self
        for name in prefix_name_list:
            if name not in node.children:
                node.children[name] = Trie()
            node = node.children[name]
            # A freshly created path node is named after its own component.
            if node.type_name is None:
                node.type_name = name

        # The node for the final path component carries the data; its type_name is
        # overwritten with the class/type component of the name.
        node.type_name = type_name
        node.has_data = True
        node.call_count_list.append(call_count)
        node.node_type = word_type
49
+
50
+
51
class DFSConverter:
    """Depth-first walker that pairs each recorded name in a trie with its renamed form.

    ``mapping`` is a dict keyed by a node's ``type_name``; each value maps a child
    component name to the name to use on the other side. Results accumulate in
    ``self.result`` across every call made on the same instance.
    """

    def __init__(self, mapping, max_depth=100):
        self.mapping = mapping
        self.max_depth = max_depth
        self.result = {}

    def traverse_and_collect(self, node, path="", mapping_path="", depth=0):
        """Recursively collect ``original name -> mapped name`` pairs below ``node``.

        Args:
            node: trie node (or None) exposing ``type_name``, ``has_data``,
                ``call_count_list``, ``node_type`` and ``children``.
            path: dot-joined original path leading to ``node``.
            mapping_path: dot-joined mapped path leading to ``node``.
            depth: current recursion depth.

        Returns:
            ``self.result``, the accumulated dict.

        Raises:
            CompareException: when ``depth`` exceeds ``self.max_depth``.
        """
        if depth > self.max_depth:
            logger.error("The converted data depth is too large, please check the data")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if node is None:
            return self.result

        current_type = node.type_name
        if node.has_data:
            # Cell nodes are named by path + count; everything else also keeps its type name.
            is_cell = node.node_type == "Cell"
            for call_count in node.call_count_list:
                if is_cell:
                    source_name = f"{path}.{call_count}"
                    target_name = f"{mapping_path}.{call_count}"
                else:
                    source_name = f"{path}.{current_type}.{call_count}"
                    target_name = f"{mapping_path}.{current_type}.{call_count}"
                self.result[source_name] = target_name

        renames = self.mapping.get(current_type, {})
        for child_key, child in node.children.items():
            renamed = renames.get(child_key, child_key)
            child_path = f"{path}.{child_key}" if path else child_key
            child_mapping_path = f"{mapping_path}.{renamed}" if mapping_path else renamed
            self.traverse_and_collect(child, child_path, child_mapping_path, depth + 1)
        return self.result
81
+
82
+
83
def get_mapping_list(ms_tree, mapping):
    """Flatten a MindSpore name trie into a list of ``(ms_name, pt_name)`` pairs.

    Names are converted with ``DFSConverter`` using ``mapping``. A leading ``Cell``
    in each converted name is then rewritten to ``Module``.
    """
    converted = DFSConverter(mapping).traverse_and_collect(ms_tree)
    return [(ms_name, re.sub(r"^Cell", "Module", pt_name)) for ms_name, pt_name in converted.items()]
91
+
92
+
93
def get_prefix_mapping(scope_list):
    """Map each cell/module name prefix to its full name.

    The prefix is the full name with its second-to-last component (the class name)
    removed. For example,
    ``Cell.network_with_loss.language_model.embedding.Embedding.3`` is stored under
    ``Cell.network_with_loss.language_model.embedding.3``.

    Only entries whose ``origin_data`` is a string starting with ``Cell`` or ``Module``
    are considered. Entries without a usable ``origin_data`` are skipped.

    Args:
        scope_list: dict of full name -> info dict holding an ``origin_data`` string.

    Returns:
        dict of prefix -> full name.

    Raises:
        CompareException: if a cell/module name has fewer than two components.
    """
    layer_mapping = {}
    for full_name, info in scope_list.items():
        origin_data = info.get("origin_data")
        # A missing/non-string origin_data cannot denote a Cell/Module entry; skip it
        # rather than failing with AttributeError on .startswith.
        if not isinstance(origin_data, str) or not origin_data.startswith(("Cell", "Module")):
            continue
        name_parts = full_name.split(Const.SEP)
        if len(name_parts) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
        prefix = Const.SEP.join(name_parts[:-2] + [name_parts[-1]])
        layer_mapping[prefix] = full_name
    return layer_mapping
108
+
109
+
110
def get_layer_mapping(ms_scope_list, pt_scope_list, mapping):
    """Pair MindSpore dump names with their PyTorch counterparts.

    Args:
        ms_scope_list: MindSpore entries keyed by normalized name; each value is a dict
            with an ``origin_data`` key holding the original dump name.
        pt_scope_list: PyTorch entries in the same format.
        mapping: per-type child rename table, forwarded to ``get_mapping_list``.

    Returns:
        list of ``(ms_name, pt_name)`` tuples. ``pt_name`` is ``None`` (or the raw falsy
        entry) when no PyTorch counterpart is found. Converted names that are neither a
        known cell prefix nor a key of ``ms_scope_list`` are left out.
    """
    # Step 1: cell prefix -> full cell name, e.g.
    # Cell.network_with_loss.language_model.embedding.3 -> Cell.network_with_loss.language_model.embedding.Embedding.3
    ms_prefix2fullname = get_prefix_mapping(ms_scope_list)

    # Step 2: build the MindSpore name trie. Each name is tagged with the first
    # component of its original dump name.
    ms_tree = Trie(type_name="Cell")
    for scope_key, scope_info in ms_scope_list.items():
        leading_part = scope_info.get('origin_data').split(Const.SEP)[0]
        ms_tree.insert(scope_key, leading_part)
    msname2ptname = get_mapping_list(ms_tree, mapping)

    # Step 3: the same prefix -> full-name table on the PyTorch side, e.g.
    # Module.network_with_loss.language_model.embedding.3 -> Module.network_with_loss.language_model.embedding.Embedding.3
    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)

    def drop_last(name):
        # Remove forward/backward: cut off the final component.
        return Const.SEP.join(name.split(Const.SEP)[:-1])

    final_mapping = []
    for ms_name, pt_name in msname2ptname:
        if ms_name in ms_prefix2fullname:
            # Cell: translate both prefixes back to full names.
            final_mapping.append((ms_prefix2fullname.get(ms_name), pt_prefix2fullname.get(pt_name, None)))
            continue
        if ms_name not in ms_scope_list:
            continue
        # Function/API: go through the original dump names.
        ms_full = drop_last(ms_scope_list.get(ms_name)['origin_data'])
        pt_entry = pt_scope_list.get(pt_name, None)
        if pt_entry:
            pt_entry = drop_last(pt_entry['origin_data'])
        final_mapping.append((ms_full, pt_entry))
    return final_mapping
@@ -0,0 +1,107 @@
1
+ from msprobe.core.common.const import Const
2
+ from msprobe.core.common.log import logger
3
+
4
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the start and end markers within a list of stack-trace lines.

    Lines are scanned in order. Each line containing ``start_sign`` moves the start
    index forward. The first line that contains ``end_sign`` (and not ``start_sign``)
    stops the scan.

    Returns:
        ``(start_pos, end_pos)``. ``start_pos`` is -1 if no start marker appeared before
        the scan stopped; ``end_pos`` is -1 if no end marker was found.
        NOTE(review): callers slice ``lines[start_pos:end_pos]``, so a -1 sentinel there
        drops the last line or yields an empty slice — confirm that is intended.
    """
    start_pos = -1
    for position, text in enumerate(lines):
        if start_sign in text:
            start_pos = position
            continue
        if end_sign in text:
            return start_pos, position
    return start_pos, -1
14
+
15
+
16
def find_stack_func_list(lines):
    """Extract function names from comma-separated stack-trace lines.

    Lines whose file field contains an entry of ``Const.FILE_SKIP_LIST``, or whose
    function field contains an entry of ``Const.FUNC_SKIP_LIST``, are ignored. The
    surviving names are returned in the reverse of their order in ``lines``.
    """
    collected = []
    for line in lines:
        fields = line.split(',')
        file_field = fields[Const.STACK_FILE_INDEX]
        if any(skip in file_field for skip in Const.FILE_SKIP_LIST):
            continue
        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(skip in func_field for skip in Const.FUNC_SKIP_LIST):
            continue
        # The function field is whitespace-split; one token of it is the name.
        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])
    collected.reverse()
    return collected
35
+
36
+
37
def get_duplicated_name(components):
    """Repeat the name component so it also serves as the type component.

    For example, ``Functional.add.X ward`` becomes ``Functional.add.add.X ward``
    (presumably "X ward" means forward/backward). If ``components`` has fewer than three
    parts, or the component at ``Const.CONSTRUCT_NAME_INDEX`` is numeric, a warning is
    logged and the same list object is returned unchanged.
    """
    if len(components) < 3 or components[Const.CONSTRUCT_NAME_INDEX].isdigit():
        logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
        return components
    head = components[:Const.CONSTRUCT_NAME_INDEX + 1]
    tail = components[Const.CONSTRUCT_NAME_INDEX:]
    return head + tail
45
+
46
+
47
+ def modify_mapping_with_stack(stack, construct):
48
+ if not stack or not construct:
49
+ return {}
50
+
51
+ # 是否是mindspore的数据结构
52
+ is_ms = any("Cell" in ii for ii in construct)
53
+ # 调整后的mapping结构
54
+ final_pres = {}
55
+ # 查看归属关系
56
+ for key in construct:
57
+ key_components = key.split(Const.SEP)
58
+ code_list = stack.get(key, None)
59
+ parent_node = construct.get(key, None)
60
+ # 名称如果非标准开头,转为标准开头
61
+ if not key.startswith(("Module", "Cell")):
62
+ # 如果没有拿到父属scope name,默认顶级域名为Module或Cell
63
+ if not parent_node:
64
+ # 将节点名字转为标准的Module或Cell
65
+ key_components[0] = "Cell" if is_ms else "Module"
66
+ # 重复该节点的名字作为类型 如add.add add在-3位置
67
+ duplicated_components = get_duplicated_name(key_components)
68
+ modified_key = Const.SEP.join(duplicated_components)
69
+
70
+ modified_key = modified_key.replace(".forward", "").replace(".backward", "")
71
+ final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
72
+ continue
73
+ parent = parent_node.split(Const.SEP)
74
+ if len(parent) < 4:
75
+ logger.info(f"Parent name in construct.json is not valid")
76
+ continue
77
+ parent_idx = Const.NAME_FIRST_POSSIBLE_INDEX if not \
78
+ parent[Const.NAME_FIRST_POSSIBLE_INDEX].isdigit() else Const.NAME_SECOND_POSSIBLE_INDEX
79
+ parent_name = parent[parent_idx]
80
+
81
+ if code_list:
82
+ # {name}.Class.count_number.X ward Or {name}.Class.count_number.X ward.ele_number
83
+ if parent_name.endswith('s'):
84
+ parent_name = parent_name[:-1]
85
+ if len(key_components) < 3:
86
+ logger.info("The length of key in construct is less than 3, please check")
87
+ continue
88
+ # {name}.count_number.X ward
89
+ func_name = key_components[-3]
90
+ start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)
91
+
92
+ # 获取指定范围的代码
93
+ regard_scope = code_list[start_pos:end_pos]
94
+
95
+ func_stack_list = find_stack_func_list(regard_scope)
96
+ else:
97
+ func_stack_list = []
98
+ # 组合逻辑:parent的节点名(到节点名字为止)加上调用栈名[reversed_list]加上原来key重复key的节点名[key_components[1:-2] + key_components[-3:]]
99
+ final_res_key = Const.SEP.join(parent[:parent_idx + 1] + func_stack_list +
100
+ key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] + key_components[Const.CONSTRUCT_NAME_INDEX:])
101
+ final_res_key = final_res_key.strip(".forward").strip(".backward")
102
+ else:
103
+ final_res_key = Const.SEP.join(key_components[:-2] + [key_components[-1]])
104
+ func_stack_list = []
105
+ final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: parent_node,
106
+ Const.STACK: Const.SEP.join(func_stack_list) if func_stack_list else None}
107
+ return final_pres
@@ -1,29 +1,46 @@
1
1
  import os
2
+ import re
2
3
  import copy
4
+ import sys
5
+ from itertools import zip_longest
6
+
3
7
  from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
4
- task_dumppath_get
5
- from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
8
+ task_dumppath_get, struct_json_get, add_time_with_yaml
9
+ from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
6
10
  from msprobe.core.common.const import Const, CompareConst
7
11
  from msprobe.core.common.log import logger
8
12
  from msprobe.core.common.exceptions import FileCheckException
9
13
  from msprobe.core.compare.acc_compare import Comparator
10
14
  from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
11
-
15
+ from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
16
+ from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
12
17
 
13
18
  class MSComparator(Comparator):
14
- def __init__(self, cell_mapping=None, api_mapping=None):
19
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
15
20
  self.frame_name = MSComparator.__name__
16
21
  self.cell_mapping = cell_mapping
17
22
  self.api_mapping = api_mapping
18
- self.cross_frame = cell_mapping is not None or api_mapping is not None
23
+ self.data_mapping = data_mapping
24
+ if data_mapping:
25
+ self.cross_frame = is_cross_framework
26
+ else:
27
+ self.cross_frame = cell_mapping is not None or api_mapping is not None
19
28
  self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
20
29
  self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
21
30
  if api_mapping is not None:
22
31
  self.ms_to_pt_mapping = self.load_internal_api()
23
-
32
+
33
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
34
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
35
+ elif isinstance(self.data_mapping, dict):
36
+ self.data_mapping_dict = self.data_mapping
37
+ else:
38
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
39
+ f"{type(self.data_mapping)}")
40
+
24
41
  def load_internal_api(self):
25
42
  cur_path = os.path.dirname(os.path.realpath(__file__))
26
- yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml")
43
+ yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
27
44
  return load_yaml(yaml_path)
28
45
 
29
46
  def load_mapping_file(self, mapping_file):
@@ -52,10 +69,12 @@ class MSComparator(Comparator):
52
69
  if self.api_mapping is not None:
53
70
  npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
54
71
  if isinstance(self.api_mapping, str):
55
- npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new, bench_dict_new)
72
+ npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
73
+ bench_dict_new)
56
74
  if target_dict:
57
75
  bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
58
- npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
76
+ npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
77
+ bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
59
78
  struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
60
79
  if not fuzzy_match:
61
80
  return npu_op_name == bench_op_name and struct_match
@@ -72,7 +91,7 @@ class MSComparator(Comparator):
72
91
  if load_pt_file:
73
92
  import torch
74
93
  from msprobe.pytorch.common.utils import load_pt
75
- data_value = load_pt(data_path).detach()
94
+ data_value = load_pt(data_path, True).detach()
76
95
  if data_value.dtype == torch.bfloat16:
77
96
  data_value = data_value.to(torch.float32)
78
97
  data_value = data_value.numpy()
@@ -99,7 +118,7 @@ class MSComparator(Comparator):
99
118
  elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
100
119
  return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
101
120
  else:
102
- return npu_op_name
121
+ return npu_op_name
103
122
 
104
123
  def remove_element(self, op_name, struct, summary, idx):
105
124
  del op_name[idx]
@@ -107,7 +126,12 @@ class MSComparator(Comparator):
107
126
  del summary[idx]
108
127
 
109
128
  def get_api_name(self, api_list):
110
- return api_list[0] + Const.SEP + api_list[1]
129
+ try:
130
+ api_name = api_list[0] + Const.SEP + api_list[1]
131
+ except IndexError as error:
132
+ logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
133
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
134
+ return api_name
111
135
 
112
136
  def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
113
137
  """
@@ -119,10 +143,13 @@ class MSComparator(Comparator):
119
143
  tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
120
144
  """
121
145
  npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
122
- npu_struct_in, bench_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT), new_bench_dict.get(CompareConst.INPUT_STRUCT)
123
- npu_struct_out, bench_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT), new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
146
+ npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
147
+ bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
148
+ npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
149
+ bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
124
150
  npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
125
- npu_in_len, bench_in_len, npu_out_len, bench_out_len = len(npu_struct_in), len(bench_struct_in), len(npu_struct_out), len(bench_struct_out)
151
+ npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
152
+ npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
126
153
  ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
127
154
  ms_api_name = self.get_api_name(ms_api_list)
128
155
  pt_api_name = self.get_api_name(pt_api_list)
@@ -130,22 +157,25 @@ class MSComparator(Comparator):
130
157
  for api_dict in self.api_mapping_dict:
131
158
  if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
132
159
  ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
133
- ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
160
+ ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
134
161
  if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
135
- logger.warning("The user-defined mapping table is incorrect, make sure that the number of parameters is equal" )
162
+ logger.warning("The user-defined mapping table is incorrect,\
163
+ make sure that the number of parameters is equal")
136
164
  break
137
165
  ms_out_list = api_dict.get("ms_output", [])
138
166
  for idx in reversed(range(npu_out_len)):
139
167
  if idx not in ms_out_list:
140
168
  del npu_struct_out[idx]
141
- del npu_summary[idx + npu_in_len]
142
- del npu_op_name[idx + npu_in_len]
169
+ if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
170
+ del npu_summary[idx + npu_in_len]
171
+ del npu_op_name[idx + npu_in_len]
143
172
  pt_out_list = api_dict.get("pt_output", [])
144
173
  for idx in reversed(range(bench_out_len)):
145
174
  if idx not in pt_out_list:
146
175
  del bench_struct_out[idx]
147
- del bench_summary[idx + bench_in_len]
148
- del bench_op_name[idx + bench_in_len]
176
+ if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
177
+ del bench_summary[idx + bench_in_len]
178
+ del bench_op_name[idx + bench_in_len]
149
179
  ms_para_list = api_dict.get("ms_args", [])
150
180
  for idx in reversed(range(npu_in_len)):
151
181
  if idx not in ms_para_list:
@@ -159,8 +189,10 @@ class MSComparator(Comparator):
159
189
  target_dict = api_dict
160
190
  break
161
191
  if target_dict:
162
- new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in, CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
163
- new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
192
+ new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
193
+ CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
194
+ new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
195
+ CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
164
196
  return new_npu_dict, new_bench_dict, target_dict
165
197
 
166
198
  def para_sequence_update(self, npu_op_name, bench_op_name):
@@ -180,25 +212,115 @@ class MSComparator(Comparator):
180
212
  if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
181
213
  return del_bench_dict
182
214
  ms_input_args_list = [i for i in range(npu_in_len)]
183
- input_sub_list =list(set(ms_input_args_list) - set(ms_user_args_list))
215
+ input_sub_list = list(set(ms_input_args_list) - set(ms_user_args_list))
184
216
  ms_output_args_list = [i for i in range(npu_out_len)]
185
- output_sub_list =list(set(ms_output_args_list) - set(ms_user_output_list))
217
+ output_sub_list = list(set(ms_output_args_list) - set(ms_user_output_list))
186
218
  bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
187
219
  bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
188
220
  bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
189
221
  bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
190
222
  for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
191
- bench_op_name.insert(idx, CompareConst.NAN)
192
- bench_struct_in.insert(idx, CompareConst.NAN)
193
- bench_summary.insert(idx, CompareConst.NAN)
223
+ bench_op_name.insert(idx, CompareConst.N_A)
224
+ bench_struct_in.insert(idx, CompareConst.N_A)
225
+ bench_summary.insert(idx, CompareConst.N_A)
194
226
  for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
195
- bench_op_name.insert(npu_in_len + idx, CompareConst.NAN)
196
- bench_struct_out.insert(idx, CompareConst.NAN)
197
- bench_summary.insert(npu_in_len + idx, CompareConst.NAN)
198
- del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
227
+ bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
228
+ bench_struct_out.insert(idx, CompareConst.N_A)
229
+ bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
230
+ del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
231
+ CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
199
232
  return del_bench_dict
200
233
 
201
-
234
+
235
+ def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag):
236
+ def generate_execution_sequence(data):
237
+ sequence_map = {}
238
+ for index, item in enumerate(data.keys()):
239
+ if flag in item:
240
+ item_split = item.split(Const.SEP)
241
+ item_name = Const.SEP.join(item_split[0:-2])
242
+ item_index = item_split[-1]
243
+ if item_index == 'forward' or item_index == 'backward':
244
+ item_index = item_split[-2]
245
+ item_key = f"{item_name}.{item_index}"
246
+ sequence_map[item_key] = index
247
+ return sequence_map
248
+
249
+ npu_map = generate_execution_sequence(npu_data)
250
+ bench_map = generate_execution_sequence(bench_data)
251
+
252
+ def sort_by_map(item):
253
+ first_key = npu_map.get(item[0], sys.maxsize)
254
+ second_key = bench_map.get(item[1], sys.maxsize)
255
+ return first_key, second_key
256
+
257
+ return sorted(mapping_list, key=sort_by_map)
258
+
259
+
260
+ def generate_kernel_data(map_value, data, flag):
261
+ if not map_value:
262
+ return [], []
263
+ inputs_name = []
264
+ outputs_name = []
265
+ map_split = map_value.split(Const.SEP)
266
+ map_name = Const.SEP.join(map_split[0:-1])
267
+ map_index = map_split[-1]
268
+ for key, value in data.items():
269
+ if key.find(flag) != -1 and key.find(map_name) != -1:
270
+ if key.split(Const.SEP)[-1] != map_index and key.split(Const.SEP)[-2] != map_index :
271
+ continue
272
+ if flag == 'forward':
273
+ input_args = value.get('input_args', {})
274
+ else:
275
+ input_args = value.get('input', {})
276
+ output_args = value.get('output', {})
277
+ for i in range(len(input_args)):
278
+ inputs_name.append(f"{key}.input.{i}")
279
+ for i in range(len(output_args)):
280
+ outputs_name.append(f"{key}.output.{i}")
281
+ return inputs_name, outputs_name
282
+
283
+
284
+ def generate_file_mapping(npu_json_path, bench_json_path, mapping_list):
285
+
286
+ npu_data = load_json(npu_json_path).get("data", {})
287
+ bench_data = load_json(bench_json_path).get("data", {})
288
+
289
+ forward_data = []
290
+ mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.FORWARD)
291
+ for map_value in mapping_list:
292
+ npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "forward")
293
+ bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "forward")
294
+ inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
295
+ outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
296
+ forward_data.extend(inputs_zip)
297
+ forward_data.extend(outputs_zip)
298
+
299
+ backward_data = []
300
+ mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.BACKWARD)
301
+ for map_value in mapping_list:
302
+ npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "backward")
303
+ bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward")
304
+ inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
305
+ outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
306
+ backward_data.extend(inputs_zip)
307
+ backward_data.extend(outputs_zip)
308
+
309
+ kernel_data = forward_data + backward_data
310
+ result = {key: value for key, value in kernel_data if key is not None}
311
+
312
+ return result
313
+
314
+
315
+ def check_cross_framework(bench_json_path):
316
+ pattern = r'"data_name":\s*"[^"]+\.pt"'
317
+ with FileOpen(bench_json_path, 'r') as file:
318
+ for line in file:
319
+ if re.search(pattern, line):
320
+ return True
321
+ return False
322
+
323
+
202
324
  def ms_compare(input_param, output_path, **kwargs):
203
325
  try:
204
326
  stack_mode = kwargs.get('stack_mode', False)
@@ -206,14 +328,30 @@ def ms_compare(input_param, output_path, **kwargs):
206
328
  fuzzy_match = kwargs.get('fuzzy_match', False)
207
329
  cell_mapping = kwargs.get('cell_mapping', None)
208
330
  api_mapping = kwargs.get('api_mapping', None)
331
+ data_mapping = kwargs.get('data_mapping', None)
332
+ layer_mapping = kwargs.get('layer_mapping', None)
333
+
209
334
  summary_compare, md5_compare = task_dumppath_get(input_param)
210
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
335
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
211
336
  create_directory(output_path)
212
337
  check_compare_param(input_param, output_path, summary_compare, md5_compare)
213
338
  except (CompareException, FileCheckException) as error:
214
339
  logger.error('Compare failed. Please check the arguments and do it again!')
215
340
  raise CompareException(error.code) from error
216
- ms_comparator = MSComparator(cell_mapping, api_mapping)
341
+ if layer_mapping:
342
+ pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK)
343
+ ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
344
+ mapping = load_yaml(layer_mapping)
345
+ ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
346
+ pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
347
+ layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping)
348
+ data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping)
349
+
350
+ data_mapping_name = add_time_with_yaml(f"data_mapping")
351
+ data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
352
+ save_yaml(data_mapping_path, data_mapping)
353
+ is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
354
+ ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
217
355
  ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
218
356
  auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
219
357
  md5_compare=md5_compare)
@@ -47,8 +47,10 @@ def npy_data_read(data_path, npy_file_list, mapping_dict):
47
47
  def statistic_data_read(statistic_file_list, statistic_file_path):
48
48
  data_list = []
49
49
  statistic_data_list = []
50
- header_index = {'Data Type': None, 'Shape': None, 'Max Value': None, 'Min Value': None,
51
- 'Avg Value': None, 'L2Norm Value': None}
50
+ header_index = {
51
+ 'Data Type': None, 'Shape': None, 'Max Value': None,
52
+ 'Min Value': None,'Avg Value': None, 'L2Norm Value': None
53
+ }
52
54
  for statistic_file in statistic_file_list:
53
55
  with FileOpen(statistic_file, "r") as f:
54
56
  csv_reader = csv.reader(f, delimiter=",")
@@ -65,8 +67,9 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
65
67
 
66
68
  for data in statistic_data_list:
67
69
  compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}"
70
+ op_name = f"{compare_key} {statistic_file_path}"
68
71
  timestamp = int(data[4])
69
- result_data = [statistic_file_path, compare_key, timestamp]
72
+ result_data = [op_name, compare_key, timestamp]
70
73
  for key in header_index.keys():
71
74
  if header_index[key] is None:
72
75
  result_data.append(np.nan)
@@ -239,9 +242,20 @@ class GraphMSComparator:
239
242
  compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
240
243
  compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
241
244
  check_path_before_create(compare_result_path)
245
+ self.to_excel(compare_result_df, compare_result_path)
246
+ logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
247
+
248
+ def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
249
+ size = len(compare_result_df)
250
+ # sheet size cannot be larger than 1048576
251
+ if size < CompareConst.MAX_EXCEL_LENGTH:
252
+ compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if need_slice else compare_result_path
242
253
  compare_result_df.to_excel(compare_result_path, index=False)
243
254
  change_mode(compare_result_path, FileCheckConst.DATA_FILE_AUTHORITY)
244
- logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
255
+ return slice_num + 1
256
+ else:
257
+ slice_num = self.to_excel(compare_result_df.iloc[0: size//2], compare_result_path, slice_num, True)
258
+ return self.to_excel(compare_result_df.iloc[size//2:], compare_result_path, slice_num, True)
245
259
 
246
260
  def compare_process(self, rank_id, step_id):
247
261
  # generate data_path
@@ -251,8 +265,8 @@ class GraphMSComparator:
251
265
  return [], ''
252
266
 
253
267
  # generate file name
254
- npu_mode = 'ERROR_MODE'
255
- bench_mode = 'ERROR_MODE'
268
+ npu_mode = GraphMode.ERROR_MODE
269
+ bench_mode = GraphMode.ERROR_MODE
256
270
  npu_data_list = []
257
271
  bench_data_list = []
258
272
  for npu_data_path in npu_data_path_list:
@@ -262,7 +276,7 @@ class GraphMSComparator:
262
276
  bench_mode, data_list = generate_data_name(bench_data_path)
263
277
  bench_data_list.extend(data_list)
264
278
 
265
- if npu_mode == "ERROR_MODE" or bench_mode == "ERROR_MODE":
279
+ if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
266
280
  logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.")
267
281
  return [], ''
268
282
  if npu_mode != bench_mode:
@@ -286,11 +300,13 @@ class GraphMSComparator:
286
300
  CompareConst.BENCH_NORM])
287
301
 
288
302
  npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
289
- npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(np.float32)
303
+ npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
290
304
 
291
- bench_float_type = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
292
- CompareConst.BENCH_NORM]
293
- bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(np.float32)
305
+ bench_float_type = [
306
+ CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
307
+ CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
308
+ ]
309
+ bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
294
310
 
295
311
  npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
296
312
  bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()