mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,12 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import copy
2
- import csv
3
17
  import glob
4
18
  import os
19
+ import re
5
20
 
6
21
  import numpy as np
7
22
  import pandas as pd
8
- from msprobe.core.common.const import CompareConst, GraphMode, Const, FileCheckConst
9
- from msprobe.core.common.file_utils import FileOpen, check_path_before_create, change_mode, load_npy
23
+ from msprobe.core.common.const import CompareConst, GraphMode, Const
24
+ from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
10
25
  from msprobe.core.common.log import logger
11
26
  from msprobe.core.common.utils import add_time_with_xlsx, CompareException
12
27
  from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
@@ -14,7 +29,7 @@ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_che
14
29
  from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
15
30
 
16
31
 
17
- class row_data:
32
+ class RowData:
18
33
  def __init__(self, mode):
19
34
  self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
20
35
  self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
@@ -28,17 +43,34 @@ class row_data:
28
43
  return self.data
29
44
 
30
45
 
46
+ def get_name_dict(name: str) -> dict:
47
+ compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
48
+ r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
49
+ match = compare_pattern.match(name)
50
+ if match:
51
+ return {'op_type': match.group(1),
52
+ 'op_name': match.group(2),
53
+ 'task_id': match.group(3),
54
+ 'stream_id': match.group(4),
55
+ 'timestamp': match.group(5).split(Const.SEP)[0],
56
+ 'input_output_index': match.group(6),
57
+ 'slot': match.group(7),
58
+ 'format': match.group(8)}
59
+ return {}
60
+
61
+
31
62
  def npy_data_read(data_path, npy_file_list, mapping_dict):
32
63
  data_list = []
64
+ compare_key_elements = ['op_name', 'task_id', 'input_output_index', 'slot']
33
65
  for data in npy_file_list:
34
66
  if data in mapping_dict:
35
- split_list = mapping_dict[data].split(Const.SEP)
67
+ name_dict = get_name_dict(mapping_dict[data])
36
68
  else:
37
- split_list = data.split(Const.SEP)
38
- if len(split_list) < 7:
69
+ name_dict = get_name_dict(data)
70
+ if not name_dict:
39
71
  continue
40
- compare_key = f"{split_list[1]}.{split_list[2]}.{split_list[3]}.{split_list[5]}.{split_list[6]}"
41
- timestamp = convert_to_int(split_list[4])
72
+ compare_key = Const.SEP.join([name_dict.get(element) for element in compare_key_elements])
73
+ timestamp = convert_to_int(name_dict.get('timestamp'))
42
74
 
43
75
  data_list.append([os.path.join(data_path, data), compare_key, timestamp])
44
76
  return data_list
@@ -48,18 +80,17 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
48
80
  data_list = []
49
81
  statistic_data_list = []
50
82
  header_index = {
51
- 'Data Type': None, 'Shape': None, 'Max Value': None,
52
- 'Min Value': None,'Avg Value': None, 'L2Norm Value': None
83
+ 'Data Type': None, 'Shape': None, 'Max Value': None,
84
+ 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
53
85
  }
54
86
  for statistic_file in statistic_file_list:
55
- with FileOpen(statistic_file, "r") as f:
56
- csv_reader = csv.reader(f, delimiter=",")
57
- header = next(csv_reader)
58
- for key in header_index.keys():
59
- for index, value in enumerate(header):
60
- if key == value:
61
- header_index[key] = index
62
- statistic_data_list.extend([row for row in csv_reader])
87
+ content = read_csv(statistic_file, as_pd=False)
88
+ header = content[0]
89
+ for key in header_index.keys():
90
+ for index, value in enumerate(header):
91
+ if key == value:
92
+ header_index[key] = index
93
+ statistic_data_list.extend(content[1:])
63
94
 
64
95
  for key in header_index.keys():
65
96
  if header_index[key] is None:
@@ -97,11 +128,9 @@ def generate_data_name(data_path):
97
128
  mapping_dict = {}
98
129
  if mapping_exist:
99
130
  for mapping_file in mapping_file_list:
100
- with FileOpen(mapping_file, "r") as f:
101
- csv_reader = csv.reader(f, delimiter=",")
102
- header = next(csv_reader)
103
- for row in csv_reader:
104
- mapping_dict[row[0]] = row[1]
131
+ content = read_csv(mapping_file, False)
132
+ for row in content[1:]:
133
+ mapping_dict[row[0]] = row[1]
105
134
 
106
135
  if npy_exist:
107
136
  data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
@@ -136,7 +165,7 @@ class GraphMSComparator:
136
165
  def compare_ops(compare_result_db, mode):
137
166
 
138
167
  def npy_mode_compute(row):
139
- result_dict = row_data(GraphMode.NPY_MODE)()
168
+ result_dict = RowData(GraphMode.NPY_MODE)()
140
169
 
141
170
  def process_npy_file(file_path, name_prefix, result):
142
171
  if os.path.exists(file_path):
@@ -171,7 +200,7 @@ class GraphMSComparator:
171
200
  return pd.Series(result_dict)
172
201
 
173
202
  def statistic_mode_compute(row):
174
- result_dict = row_data('STATISTIC')()
203
+ result_dict = RowData('STATISTIC')()
175
204
 
176
205
  def update_result_dict(result, rows, prefix):
177
206
  result[f'{prefix} Name'] = rows[f'{prefix} Name']
@@ -198,24 +227,30 @@ class GraphMSComparator:
198
227
  result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
199
228
  result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
200
229
  CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
201
- result_dict[CompareConst.MAX_RELATIVE_ERR] = str(result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
230
+ if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
231
+ result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
232
+ result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
202
233
  result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
203
234
  CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
204
- result_dict[CompareConst.MIN_RELATIVE_ERR] = str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
235
+ if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
236
+ result_dict[CompareConst.MIN_RELATIVE_ERR] = \
237
+ str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
205
238
  result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
206
239
  CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
207
- result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
208
- result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
240
+ if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
241
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
242
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
209
243
  result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
210
244
  CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
211
- result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
212
- result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
245
+ if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
246
+ result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
247
+ result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
213
248
  magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
214
249
  max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
215
- if magnitude_diff > CompareConst.MAGNITUDE:
216
- result_dict[CompareConst.ACCURACY] = 'No'
217
- else:
218
- result_dict[CompareConst.ACCURACY] = 'Yes'
250
+ if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
251
+ magnitude_diff = 0
252
+ result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
253
+ magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
219
254
 
220
255
  return pd.Series(result_dict)
221
256
 
@@ -238,24 +273,23 @@ class GraphMSComparator:
238
273
  is_empty = True
239
274
  if is_empty or not mode:
240
275
  continue
241
- compare_result_df = self._do_multi_process(compare_result_df, mode)
276
+ compare_result_df = self.do_multi_process(compare_result_df, mode)
242
277
  compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
243
278
  compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
244
- check_path_before_create(compare_result_path)
245
279
  self.to_excel(compare_result_df, compare_result_path)
246
280
  logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
247
-
281
+
248
282
  def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
249
283
  size = len(compare_result_df)
250
284
  # sheet size cannot be larger than 1048576
251
285
  if size < CompareConst.MAX_EXCEL_LENGTH:
252
- compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if need_slice else compare_result_path
253
- compare_result_df.to_excel(compare_result_path, index=False)
254
- change_mode(compare_result_path, FileCheckConst.DATA_FILE_AUTHORITY)
286
+ compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
287
+ need_slice else compare_result_path
288
+ save_excel(compare_result_path, compare_result_df)
255
289
  return slice_num + 1
256
290
  else:
257
- slice_num = self.to_excel(compare_result_df.iloc[0: size//2], compare_result_path, slice_num, True)
258
- return self.to_excel(compare_result_df.iloc[size//2:], compare_result_path, slice_num, True)
291
+ slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
292
+ return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
259
293
 
260
294
  def compare_process(self, rank_id, step_id):
261
295
  # generate data_path
@@ -303,8 +337,8 @@ class GraphMSComparator:
303
337
  npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
304
338
 
305
339
  bench_float_type = [
306
- CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
307
- CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
340
+ CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
341
+ CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
308
342
  ]
309
343
  bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
310
344
 
@@ -355,7 +389,7 @@ class GraphMSComparator:
355
389
  rank_step_path_dict[rank_step_key] = [dir_path]
356
390
  return dict(sorted(rank_step_path_dict.items()))
357
391
 
358
- def _do_multi_process(self, result_df, mode):
392
+ def do_multi_process(self, result_df, mode):
359
393
  try:
360
394
  result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
361
395
  except ValueError as e:
@@ -33,7 +33,7 @@ class DebuggerConfig:
33
33
  self.level_ori = common_config.level
34
34
  self.list = [] if not task_config.list else task_config.list
35
35
  self.scope = [] if not task_config.scope else task_config.scope
36
- self.data_mode = [] if not task_config.data_mode else task_config.data_mode
36
+ self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
37
37
  self.file_format = task_config.file_format
38
38
  self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
39
39
  self.check_mode = task_config.check_mode
@@ -52,6 +52,9 @@ class DebuggerConfig:
52
52
  self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
53
53
  raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
54
54
  f"but got {self.pert_type}.")
55
+ if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
56
+ raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
57
+ f"but got {self.handler_type}.")
55
58
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
56
59
 
57
60
  def check(self):
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ from collections import defaultdict
17
18
 
18
19
  import mindspore as ms
19
20
  from mindspore._c_expression import MSContext
20
21
 
21
22
  from msprobe.core.common.const import Const, MsgConst
23
+ from msprobe.mindspore.cell_processor import CellProcessor
22
24
  from msprobe.mindspore.common.const import Const as MsConst
23
25
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
26
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
24
27
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
25
28
  from msprobe.mindspore.ms_config import parse_json_config
26
29
  from msprobe.mindspore.runtime import Runtime
@@ -128,6 +131,9 @@ class PrecisionDebugger:
128
131
  return
129
132
  if instance.service:
130
133
  instance.service.step()
134
+ HOOKCell.cell_count = defaultdict(int)
135
+ CellProcessor.reset_cell_stats()
136
+
131
137
  Runtime.step_count += 1
132
138
 
133
139
  @classmethod
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -40,6 +40,8 @@ class DumpToolFactory:
40
40
 
41
41
  @staticmethod
42
42
  def create(config: DebuggerConfig):
43
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
44
+ raise Exception("data_mode must be one of all, input, output.")
43
45
  tool = DumpToolFactory.tools.get(config.level)
44
46
  if not tool:
45
47
  raise Exception("Valid level is needed.")
@@ -24,6 +24,12 @@ from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTenso
24
24
  from msprobe.core.common.utils import Const
25
25
 
26
26
 
27
+ def stub_method(method):
28
+ def wrapped_method(*args, **kwargs):
29
+ return method(*args, **kwargs)
30
+ return wrapped_method
31
+
32
+
27
33
  class ApiRegistry:
28
34
  def __init__(self):
29
35
  self.tensor_ori_attr = {}
@@ -50,9 +56,13 @@ class ApiRegistry:
50
56
  if Const.SEP in api:
51
57
  sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
52
58
  sub_module = getattr(ori_api_group, sub_module_name)
53
- api_ori_attr[api] = getattr(sub_module, sub_op)
59
+ ori_api_func = getattr(sub_module, sub_op)
54
60
  else:
55
- api_ori_attr[api] = getattr(ori_api_group, api)
61
+ ori_api_func = getattr(ori_api_group, api)
62
+ if ori_api_group == StubTensor:
63
+ api_ori_attr[api] = stub_method(ori_api_func)
64
+ continue
65
+ api_ori_attr[api] = ori_api_func
56
66
 
57
67
  @staticmethod
58
68
  def set_api_attr(api_group, attr_dict):
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,18 +12,16 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
16
  import os
17
17
 
18
- import mindspore as ms
19
- from mindspore.common.tensor import Tensor
20
18
  from mindspore import ops
19
+ from mindspore.common.tensor import Tensor
21
20
 
22
- from msprobe.mindspore.common.log import logger
23
21
  from msprobe.core.common.utils import Const, DumpException
24
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
25
- ModuleBackwardInputs, ModuleBackwardOutputs
22
+ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
23
+ ModuleForwardInputsOutputs)
24
+ from msprobe.mindspore.common.log import logger
26
25
 
27
26
 
28
27
  class PrimitiveHookService:
@@ -41,6 +40,7 @@ class PrimitiveHookService:
41
40
  Returns:
42
41
  callable: 包装后的 primitive 函数。
43
42
  """
43
+
44
44
  def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
45
45
  """
46
46
  创建反向 hook 函数,用于捕获梯度。
@@ -54,26 +54,24 @@ class PrimitiveHookService:
54
54
  Returns:
55
55
  callable: 反向 hook 函数。
56
56
  """
57
- def backward_hook(grad):
58
57
 
59
- captured_grads.append(grad)
58
+ def backward_hook(grad):
59
+ captured_grads.extend(grad)
60
60
  backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
61
61
 
62
62
  try:
63
- if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
63
+ if hook_type == Const.INPUT:
64
64
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
65
65
  new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
66
66
  self.service_instance.data_collector.backward_output_data_collect(
67
67
  backward_primitive_name, self, os.getpid(), new_module_input_output
68
68
  )
69
- captured_grads.clear()
70
- elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
69
+ elif hook_type == Const.OUTPUT:
71
70
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
72
71
  new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
73
72
  self.service_instance.data_collector.backward_input_data_collect(
74
73
  backward_primitive_name, self, os.getpid(), new_module_input_output
75
74
  )
76
- captured_grads.clear()
77
75
 
78
76
  except Exception as exception:
79
77
  logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
@@ -104,7 +102,7 @@ class PrimitiveHookService:
104
102
  hooked_inputs.append(arg_hooked)
105
103
  else:
106
104
  hooked_inputs.append(arg)
107
- return hooked_inputs
105
+ return tuple(hooked_inputs)
108
106
 
109
107
  def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
110
108
  """
@@ -178,7 +176,7 @@ class PrimitiveHookService:
178
176
  module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
179
177
  try:
180
178
  self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
181
- os.getpid(), module_input_output)
179
+ os.getpid(), module_input_output)
182
180
  except Exception as exception:
183
181
  logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
184
182
  f"primitive_name: {primitive_name}")
@@ -203,4 +201,3 @@ class PrimitiveHookService:
203
201
  self.primitive_counters[primitive_name] = 0
204
202
  else:
205
203
  self.primitive_counters[primitive_name] += 1
206
-
@@ -490,6 +490,31 @@ ops:
490
490
  - scatter_update
491
491
  - derivative
492
492
  - jet
493
+ - row_stack
494
+ - gather
495
+ - arange
496
+ - cond
497
+ - slice_scatter
498
+ - clip_by_norm
499
+ - eps
500
+ - layer_norm
501
+ - cast
502
+ - numel
503
+ - permute
504
+ - select_scatter
505
+ - group_norm
506
+ - eq
507
+ - embedding
508
+ - ones_like
509
+ - zeros
510
+ - nanmean
511
+ - shape
512
+ - zeros_like
513
+ - ones
514
+ - diagonal_scatter
515
+ - vander
516
+ - is_nonzero
517
+ - rotary_position_embedding
493
518
 
494
519
  tensor:
495
520
  - __abs__
@@ -20,7 +20,7 @@ from mindspore import Tensor
20
20
  from mindspore._c_expression import PyNativeExecutor_
21
21
  from mindspore.common.api import _MindsporeFunctionExecutor
22
22
 
23
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
23
+ from msprobe.core.common.log import logger
24
24
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
25
25
  from msprobe.core.common.const import Const
26
26
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
@@ -33,6 +33,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
33
33
  index = ori_args.find("<")
34
34
  if index != 0 and index != -1:
35
35
  result = ori_args[0:index]
36
+ elif name is not None and "<" not in str(name):
37
+ result = str(name)
36
38
  else:
37
39
  result = "JitFunction"
38
40
  if JitDump.need_dump():
@@ -47,7 +49,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
47
49
  name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
48
50
  Const.BACKWARD
49
51
  JitDump.data_collector.update_api_or_module_name(name_template)
50
- module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat ,grad_output=out_feat)
52
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
51
53
  JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
52
54
 
53
55
 
@@ -59,15 +61,25 @@ class JitDump(_MindsporeFunctionExecutor):
59
61
 
60
62
  def __init__(self, *args, **kwargs):
61
63
  super().__init__(*args, **kwargs)
64
+ self.name = None
65
+ if len(args) > 0:
66
+ self.name = args[0].__name__
62
67
  self._executor = PyNativeExecutor_.get_instance()
63
68
 
64
69
  def __call__(self, *args, **kwargs):
65
- api_register.api_set_ori_func()
70
+ if JitDump.jit_dump_switch:
71
+ api_register.api_set_ori_func()
66
72
  out = super().__call__(*args, **kwargs)
67
73
  if JitDump.jit_dump_switch and len(args) > 0:
68
- dump_jit(args[0], args, out, True)
74
+ if self.name and self.name != "construct":
75
+ dump_jit(self.name, args, out, True)
76
+ else:
77
+ dump_jit(args[0], args, out, True)
69
78
  JitDump.jit_enable = True
70
- api_register.api_set_hook_func()
79
+ elif len(args) == 0:
80
+ logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
81
+ if JitDump.jit_dump_switch:
82
+ api_register.api_set_hook_func()
71
83
  return out
72
84
 
73
85
  @classmethod
@@ -13,10 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
16
  import os
18
17
 
19
- from msprobe.core.common.file_utils import FileOpen, create_directory
18
+ from msprobe.core.common.file_utils import create_directory, save_json
20
19
  from msprobe.mindspore.common.log import logger
21
20
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
22
21
 
@@ -62,8 +61,7 @@ class KernelGraphDump:
62
61
  json_path = self.dump_json["common_dump_settings"]["path"]
63
62
  create_directory(json_path)
64
63
  json_path = os.path.join(json_path, "kernel_graph_dump.json")
65
- with FileOpen(json_path, 'w') as f:
66
- json.dump(self.dump_json, f)
64
+ save_json(json_path, self.dump_json, indent=4)
67
65
  logger.info(json_path + " has been created.")
68
66
  os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
69
67
  if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
@@ -13,11 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
16
  import os
18
17
 
19
18
  from msprobe.core.common.const import Const
20
- from msprobe.core.common.file_utils import FileOpen, create_directory
19
+ from msprobe.core.common.file_utils import create_directory, save_json
21
20
  from msprobe.mindspore.common.log import logger
22
21
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
23
22
 
@@ -70,8 +69,7 @@ class KernelKbykDump:
70
69
  json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
71
70
  create_directory(json_path)
72
71
  json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
73
- with FileOpen(json_path, 'w') as f:
74
- json.dump(self.dump_json, f)
72
+ save_json(json_path, self.dump_json, indent=4)
75
73
  logger.info(json_path + " has been created.")
76
74
 
77
75
  os.environ["MINDSPORE_DUMP_CONFIG"] = json_path