mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import time
3
18
  import sys
@@ -5,6 +20,16 @@ from functools import wraps
5
20
  from msprobe.core.common.const import MsgConst
6
21
 
7
22
 
23
+ def filter_special_chars(func):
24
+ @wraps(func)
25
+ def func_level(self, msg, **kwargs):
26
+ for char in MsgConst.SPECIAL_CHAR:
27
+ msg = msg.replace(char, '_')
28
+ return func(self, msg, **kwargs)
29
+
30
+ return func_level
31
+
32
+
8
33
  class BaseLogger:
9
34
  def __init__(self):
10
35
  self.rank = None
@@ -21,14 +46,6 @@ class BaseLogger:
21
46
  def get_rank(self):
22
47
  return self.rank
23
48
 
24
- def filter_special_chars(func):
25
- @wraps(func)
26
- def func_level(self, msg, **kwargs):
27
- for char in MsgConst.SPECIAL_CHAR:
28
- msg = msg.replace(char, '_')
29
- return func(self, msg, **kwargs)
30
- return func_level
31
-
32
49
  @filter_special_chars
33
50
  def error(self, msg):
34
51
  if self.level <= MsgConst.LogLevel.ERROR.value:
@@ -56,6 +73,7 @@ class BaseLogger:
56
73
  return func(*args, **kwargs)
57
74
  else:
58
75
  return None
76
+
59
77
  return func_rank_0
60
78
 
61
79
  def info_on_rank_0(self, msg):
@@ -66,7 +84,7 @@ class BaseLogger:
66
84
 
67
85
  def warning_on_rank_0(self, msg):
68
86
  return self.on_rank_0(self.warning)(msg)
69
-
87
+
70
88
  def error_log_with_exp(self, msg, exception):
71
89
  self.error(msg)
72
90
  raise exception
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,21 +12,23 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import collections
18
17
  import os
19
18
  import re
20
19
  import subprocess
21
20
  import time
22
- import json
21
+ from collections import defaultdict
23
22
  from datetime import datetime, timezone
23
+ from functools import wraps
24
+
25
+ import numpy as np
24
26
 
25
27
  from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
26
28
  from msprobe.core.common.const import Const, CompareConst
27
29
  from msprobe.core.common.log import logger
28
30
  from msprobe.core.common.exceptions import MsprobeException
29
31
 
30
-
31
32
  device = collections.namedtuple('device', ['type', 'index'])
32
33
  prefixes = ['api_stack', 'list', 'range', 'acl']
33
34
 
@@ -68,6 +69,8 @@ class MsprobeBaseException(Exception):
68
69
  FUNCTION_CALL_ERROR = 28
69
70
  FORWARD_DATA_COLLECTION_ERROR = 29
70
71
  BACKWARD_DATA_COLLECTION_ERROR = 30
72
+ INVALID_KEY_ERROR = 31
73
+ MISSING_HEADER_ERROR = 32
71
74
 
72
75
  def __init__(self, code, error_info: str = ""):
73
76
  super(MsprobeBaseException, self).__init__()
@@ -99,7 +102,14 @@ class DumpException(MsprobeBaseException):
99
102
  return f"Dump Error Code {self.code}: {self.error_info}"
100
103
 
101
104
 
102
- def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
105
+ def is_json_file(file_path):
106
+ if isinstance(file_path, str) and file_path.lower().endswith('.json'):
107
+ return True
108
+ else:
109
+ return False
110
+
111
+
112
+ def check_compare_param(input_param, output_path, dump_mode):
103
113
  if not isinstance(input_param, dict):
104
114
  logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
105
115
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -107,10 +117,19 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
107
117
  logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
108
118
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
109
119
 
110
- check_file_or_directory_path(input_param.get("npu_json_path"), False)
111
- check_file_or_directory_path(input_param.get("bench_json_path"), False)
112
- check_file_or_directory_path(input_param.get("stack_json_path"), False)
113
- if not summary_compare and not md5_compare:
120
+ def check_json_path(json_path_str):
121
+ json_path = input_param.get(json_path_str)
122
+ check_file_or_directory_path(json_path, False)
123
+ json_type_check = is_json_file(json_path)
124
+ if not json_type_check:
125
+ logger.error(f"Invalid {json_path_str}: {json_path}, please check!")
126
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
127
+
128
+ check_json_path("npu_json_path")
129
+ check_json_path("bench_json_path")
130
+ check_json_path("stack_json_path")
131
+
132
+ if dump_mode == Const.ALL:
114
133
  check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
115
134
  check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
116
135
  check_file_or_directory_path(output_path, True)
@@ -179,7 +198,7 @@ def execute_command(cmd):
179
198
  line = process.stdout.readline()
180
199
  line = line.strip()
181
200
  if line:
182
- print(line)
201
+ logger.info(line)
183
202
  if process.returncode != 0:
184
203
  logger.error('Failed to execute command:%s' % " ".join(cmd))
185
204
  raise CompareException(CompareException.INVALID_DATA_ERROR)
@@ -212,25 +231,29 @@ def md5_find(data):
212
231
  for data_detail in data[key_op][api_info]:
213
232
  if data_detail and 'md5' in data_detail:
214
233
  return True
215
- elif 'md5' in data[key_op][api_info]:
234
+ elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
216
235
  return True
217
236
  return False
218
237
 
219
238
 
220
- def struct_json_get(input_param, framework):
221
- if framework == Const.PT_FRAMEWORK:
222
- prefix = "bench"
223
- elif framework == Const.MS_FRAMEWORK:
224
- prefix = "npu"
225
- else:
226
- logger.error("Error framework found.")
227
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
239
+ def detect_framework_by_dump_json(file_path):
240
+ pattern_ms = r'"type":\s*"mindspore'
241
+ pattern_pt = r'"type":\s*"torch'
242
+ with FileOpen(file_path, 'r') as file:
243
+ for line in file:
244
+ if re.search(pattern_ms, line):
245
+ return Const.MS_FRAMEWORK
246
+ if re.search(pattern_pt, line):
247
+ return Const.PT_FRAMEWORK
248
+ logger.error(f"{file_path} must be based on the MindSpore or PyTorch framework.")
249
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
250
+
228
251
 
229
- frame_json_path = input_param.get(f"{prefix}_json_path", None)
230
- if not frame_json_path:
231
- logger.error(f"Please check the json path is valid.")
252
+ def get_stack_construct_by_dump_json_path(dump_json_path):
253
+ if not dump_json_path:
254
+ logger.error("The path is empty. Please enter a valid path.")
232
255
  raise CompareException(CompareException.INVALID_PATH_ERROR)
233
- directory = os.path.dirname(frame_json_path)
256
+ directory = os.path.dirname(dump_json_path)
234
257
  check_file_or_directory_path(directory, True)
235
258
  stack_json = os.path.join(directory, "stack.json")
236
259
  construct_json = os.path.join(directory, "construct.json")
@@ -240,41 +263,57 @@ def struct_json_get(input_param, framework):
240
263
  return stack, construct
241
264
 
242
265
 
243
- def task_dumppath_get(input_param):
266
+ def set_dump_path(input_param):
244
267
  npu_path = input_param.get("npu_json_path", None)
245
268
  bench_path = input_param.get("bench_json_path", None)
246
- if not npu_path or not bench_path:
247
- logger.error(f"Please check the json path is valid.")
269
+ npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
270
+ bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
271
+ if not npu_path_valid or not bench_path_valid:
272
+ logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
248
273
  raise CompareException(CompareException.INVALID_PATH_ERROR)
249
- with FileOpen(npu_path, 'r') as npu_f:
250
- npu_json_data = json.load(npu_f)
251
- with FileOpen(bench_path, 'r') as bench_f:
252
- bench_json_data = json.load(bench_f)
253
- if npu_json_data['task'] != bench_json_data['task']:
274
+ input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
275
+ input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
276
+
277
+
278
+ def get_dump_mode(input_param):
279
+ npu_path = input_param.get("npu_json_path", None)
280
+ bench_path = input_param.get("bench_json_path", None)
281
+ npu_json_data = load_json(npu_path)
282
+ bench_json_data = load_json(bench_path)
283
+
284
+ npu_task = npu_json_data.get('task', None)
285
+ bench_task = bench_json_data.get('task', None)
286
+
287
+ if not npu_task or not bench_task:
288
+ logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
289
+ raise CompareException(CompareException.INVALID_TASK_ERROR)
290
+
291
+ if npu_task != bench_task:
254
292
  logger.error(f"Please check the dump task is consistent.")
255
293
  raise CompareException(CompareException.INVALID_TASK_ERROR)
256
- if npu_json_data['task'] == Const.TENSOR:
257
- summary_compare = False
258
- md5_compare = False
259
- elif npu_json_data['task'] == Const.STATISTICS:
260
- md5_compare = md5_find(npu_json_data['data'])
261
- if md5_compare:
262
- summary_compare = False
294
+
295
+ if npu_task == Const.TENSOR:
296
+ return Const.ALL
297
+
298
+ if npu_task == Const.STATISTICS:
299
+ npu_md5_compare = md5_find(npu_json_data['data'])
300
+ bench_md5_compare = md5_find(bench_json_data['data'])
301
+ if npu_md5_compare == bench_md5_compare:
302
+ return Const.MD5 if npu_md5_compare else Const.SUMMARY
263
303
  else:
264
- summary_compare = True
265
- else:
266
- logger.error(f"Compare is not required for overflow_check or free_benchmark.")
267
- raise CompareException(CompareException.INVALID_TASK_ERROR)
268
- input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
269
- input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
270
- return summary_compare, md5_compare
304
+ logger.error(f"Please check the dump task is consistent, "
305
+ f"dump mode of npu and bench should both be statistics or md5.")
306
+ raise CompareException(CompareException.INVALID_TASK_ERROR)
271
307
 
308
+ logger.error(f"Compare applies only to task is tensor or statistics")
309
+ raise CompareException(CompareException.INVALID_TASK_ERROR)
272
310
 
273
- def get_header_index(header_name, summary_compare=False):
274
- if summary_compare:
275
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
276
- else:
277
- header = CompareConst.COMPARE_RESULT_HEADER[:]
311
+
312
+ def get_header_index(header_name, dump_mode):
313
+ header = CompareConst.HEAD_OF_COMPARE_MODE.get(dump_mode)
314
+ if not header:
315
+ logger.error(f"{dump_mode} not in {CompareConst.HEAD_OF_COMPARE_MODE}")
316
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
278
317
  if header_name not in header:
279
318
  logger.error(f"{header_name} not in data name")
280
319
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -282,7 +321,7 @@ def get_header_index(header_name, summary_compare=False):
282
321
 
283
322
 
284
323
  def convert_tuple(data):
285
- return data if isinstance(data, tuple) else (data, )
324
+ return data if isinstance(data, tuple) else (data,)
286
325
 
287
326
 
288
327
  def check_op_str_pattern_valid(string, op_name=None, stack=False):
@@ -302,6 +341,10 @@ def is_invalid_pattern(string):
302
341
  return re.search(pattern, string)
303
342
 
304
343
 
344
+ def is_int(x):
345
+ return isinstance(x, int) and not isinstance(x, bool)
346
+
347
+
305
348
  def print_tools_ends_info():
306
349
  total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
307
350
  logger.info('*' * total_len)
@@ -315,40 +358,47 @@ def get_step_or_rank_from_string(step_or_rank, obj):
315
358
  try:
316
359
  borderlines = int(splited[0]), int(splited[1])
317
360
  except (ValueError, IndexError) as e:
318
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
361
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
319
362
  "The hyphen(-) must start and end with decimal numbers.") from e
320
363
  else:
321
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
322
- f'The string parameter for {obj} only supports formats like "3-5". Now string parameter for {obj} is "{step_or_rank}".')
323
- if all(Const.STEP_RANK_MAXIMUM_RANGE[0] <= b <= Const.STEP_RANK_MAXIMUM_RANGE[1] for b in borderlines):
364
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
365
+ f'The string parameter for {obj} only supports formats like "3-5". '
366
+ f'Now string parameter for {obj} is "{step_or_rank}".')
367
+ if all(Const.STEP_RANK_MINIMUM_VALUE <= b <= Const.STEP_RANK_MAXIMUM_VALUE for b in borderlines):
324
368
  if borderlines[0] <= borderlines[1]:
325
369
  continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
326
370
  else:
327
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
328
- f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be greater than the right boundary ({borderlines[1]}).')
371
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
372
+ f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be '
373
+ f'greater than the right boundary ({borderlines[1]}).')
329
374
  else:
330
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
331
- f"The boundaries must fall within the range of [{Const.STEP_RANK_MAXIMUM_RANGE[0]}, {Const.STEP_RANK_MAXIMUM_RANGE[1]}].")
375
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
376
+ f"The boundaries must fall within the range of "
377
+ f"[{Const.STEP_RANK_MINIMUM_VALUE}, {Const.STEP_RANK_MAXIMUM_VALUE}].")
332
378
  return continual_step_or_rank
333
379
 
334
380
 
335
381
  def get_real_step_or_rank(step_or_rank_input, obj):
336
382
  if obj not in [Const.STEP, Const.RANK]:
337
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
383
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
338
384
  f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
339
385
  if step_or_rank_input is None:
340
386
  return []
341
387
  if not isinstance(step_or_rank_input, list):
342
388
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
389
+ if len(step_or_rank_input) > Const.STEP_RANK_MAXIMUM_VALUE:
390
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
391
+ f"{obj} is invalid, its length cannot exceed {Const.STEP_RANK_MAXIMUM_VALUE}")
392
+
343
393
  real_step_or_rank = []
344
394
  for element in step_or_rank_input:
345
- if not isinstance(element, (int, str)):
346
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
395
+ if not is_int(element) and not isinstance(element, str):
396
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
347
397
  f"{obj} element {element} must be an integer or string.")
348
398
  if isinstance(element, int) and element < 0:
349
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
399
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
350
400
  f"Each element of {obj} must be non-negative, currently it is {element}.")
351
- if isinstance(element, int) and Const.STEP_RANK_MAXIMUM_RANGE[0] <= element <= Const.STEP_RANK_MAXIMUM_RANGE[1]:
401
+ if isinstance(element, int) and Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
352
402
  real_step_or_rank.append(element)
353
403
  elif isinstance(element, str) and Const.HYPHEN in element:
354
404
  continual_step_or_rank = get_step_or_rank_from_string(element, obj)
@@ -359,7 +409,7 @@ def get_real_step_or_rank(step_or_rank_input, obj):
359
409
 
360
410
 
361
411
  def check_seed_all(seed, mode):
362
- if isinstance(seed, int):
412
+ if is_int(seed):
363
413
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
364
414
  logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
365
415
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
@@ -369,3 +419,66 @@ def check_seed_all(seed, mode):
369
419
  if not isinstance(mode, bool):
370
420
  logger.error("seed_all mode must be bool.")
371
421
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
422
+
423
+
424
def safe_get_value(container, index, container_name, key=None):
    """Fetch ``container[index]`` and convert lookup failures into MsprobeBaseException.

    Args:
        container: A dict, list, tuple or numpy array to read from.
        index: Position (or subscript) to read.
        container_name: Name used for the container in error messages.
        key: For dict containers only, the key whose value is indexed.

    Returns:
        ``container.get(key)[index]`` for a dict, ``container[index]`` for a
        list, tuple or numpy array.

    Raises:
        MsprobeBaseException: INVALID_OBJECT_TYPE_ERROR for an unsupported
            container type or a TypeError during the lookup;
            INDEX_OUT_OF_BOUNDS_ERROR for an IndexError.
    """
    try:
        if isinstance(container, dict):
            # Dict case: a missing key yields None here, which surfaces below as a TypeError.
            return container.get(key)[index]
        if isinstance(container, (list, tuple, np.ndarray)):
            # List, tuple and numpy array case.
            return container[index]
        err_msg = f"Unsupported container type for '{container_name}': {type(container)}"
        logger.error(err_msg)
        raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR)
    except IndexError as e:
        err_msg = (
            "index out of bounds error occurs, please check!\n"
            f"{container_name} is {container}\n"
            f"index is {index}"
        )
        logger.error(err_msg)
        raise MsprobeBaseException(MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    except TypeError as e:
        err_msg = (
            "wrong type, please check!\n"
            f"{container_name} is {container}\n"
            f"index is {index}\n"
            f"key is {key}"
        )
        logger.error(err_msg)
        raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
449
+
450
+
451
# Tracks the current recursion depth of each decorated utility function, keyed by id(func).
recursion_depth = defaultdict(int)


def recursion_depth_decorator(func_info):
    """Build a decorator that reports an error when the wrapped function recurses too deeply.

    Args:
        func_info: Text identifying the function; it is interpolated into the
            error message.

    Returns:
        A decorator. The wrapped function counts its active nested calls in
        ``recursion_depth`` and, once the count exceeds ``Const.MAX_DEPTH``,
        reports a ``MsprobeException`` with ``RECURSION_LIMIT_ERROR`` through
        ``logger.error_log_with_exp``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            try:
                # The limit check lives inside the try so that the increment above is
                # always undone by the finally clause. If the report below raises
                # (error_log_with_exp appears to raise the given exception - confirm),
                # a check placed before the try would leave the counter one too high
                # after every overflow.
                if recursion_depth[func_id] > Const.MAX_DEPTH:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
479
+
480
+
481
def check_str_param(param):
    """Reject a string parameter that does not match ``Const.REGEX_PREFIX_PATTERN``.

    Args:
        param: The string to validate.

    Raises:
        MsprobeBaseException: INVALID_CHAR_ERROR when ``re.match`` finds no
            match of the pattern at the start of ``param``.
    """
    if re.match(Const.REGEX_PREFIX_PATTERN, param):
        return
    logger.error('The parameter {} contains special characters.'.format(param))
    raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
@@ -1,7 +1,21 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const, FileCheckConst
2
17
  from msprobe.core.common.log import logger
3
18
  from msprobe.core.common.exceptions import MsprobeException
4
- from msprobe.core.common.file_utils import FileChecker
5
19
  from msprobe.core.common.utils import get_real_step_or_rank
6
20
 
7
21
 
@@ -12,7 +26,6 @@ class CommonConfig:
12
26
  self.rank = get_real_step_or_rank(json_config.get('rank'), Const.RANK)
13
27
  self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
14
28
  self.level = json_config.get('level')
15
- self.acl_config = json_config.get('acl_config')
16
29
  self.enable_dataloader = json_config.get('enable_dataloader', False)
17
30
  self._check_config()
18
31
 
@@ -29,16 +42,6 @@ class CommonConfig:
29
42
  if not isinstance(self.enable_dataloader, bool):
30
43
  logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
31
44
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
32
- if self.acl_config:
33
- self._check_acl_config()
34
-
35
- def _check_acl_config(self):
36
- if not isinstance(self.acl_config, str):
37
- logger.error_log_with_exp("acl_config is invalid, it should be a string",
38
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
39
- file_checker = FileChecker(
40
- file_path=self.acl_config, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
41
- file_checker.common_check()
42
45
 
43
46
 
44
47
  class BaseConfig:
@@ -46,7 +49,6 @@ class BaseConfig:
46
49
  self.scope = json_config.get('scope')
47
50
  self.list = json_config.get('list')
48
51
  self.data_mode = json_config.get('data_mode')
49
- self.backward_input = json_config.get("backward_input")
50
52
  self.file_format = json_config.get("file_format")
51
53
  self.summary_mode = json_config.get("summary_mode")
52
54
  self.overflow_nums = json_config.get("overflow_nums")
@@ -74,5 +76,32 @@ class BaseConfig:
74
76
  def check_config(self):
75
77
  self._check_str_list_config(self.scope, "scope")
76
78
  self._check_str_list_config(self.list, "list")
77
- self._check_str_list_config(self.data_mode, "data_mode")
78
- self._check_str_list_config(self.backward_input, "backward_input")
79
+ self._check_data_mode()
80
+
81
def _check_data_mode(self):
    """Validate the optional ``data_mode`` setting.

    ``None`` is accepted as "not configured". Otherwise ``data_mode`` must be
    a list of strings drawn from ``Const.DUMP_DATA_MODE_LIST``, must hold
    fewer entries than that list, and may contain ``Const.ALL`` only as its
    sole element. Each violation is reported through
    ``logger.error_log_with_exp`` together with a ``MsprobeException``
    carrying ``INVALID_PARAM_ERROR``.
    """
    if self.data_mode is None:
        return

    if not isinstance(self.data_mode, list):
        logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))

    if Const.ALL in self.data_mode and len(self.data_mode) != 1:
        logger.error_log_with_exp(
            "'all' cannot be combined with other options in data_mode.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    if len(self.data_mode) >= len(Const.DUMP_DATA_MODE_LIST):
        # Message previously misspelled the option name as "data_made".
        logger.error_log_with_exp(
            f"The number of elements in the data_mode cannot exceed {len(Const.DUMP_DATA_MODE_LIST) - 1}.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    for mode in self.data_mode:
        if not isinstance(mode, str):
            logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                      MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
        if mode not in Const.DUMP_DATA_MODE_LIST:
            logger.error_log_with_exp(
                f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
                MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            )