mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -22,19 +22,21 @@ import time
22
22
  import json
23
23
  from datetime import datetime, timezone
24
24
 
25
- from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
25
+ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
26
26
  from msprobe.core.common.const import Const, CompareConst
27
27
  from msprobe.core.common.log import logger
28
+ from msprobe.core.common.exceptions import MsprobeException
28
29
 
29
30
 
30
31
  device = collections.namedtuple('device', ['type', 'index'])
31
32
  prefixes = ['api_stack', 'list', 'range', 'acl']
32
33
 
33
34
 
34
- class CompareException(Exception):
35
+ class MsprobeBaseException(Exception):
35
36
  """
36
- Class for Accuracy Compare Exception
37
+ Base class for all custom exceptions.
37
38
  """
39
+ # 所有的错误代码
38
40
  NONE_ERROR = 0
39
41
  INVALID_PATH_ERROR = 1
40
42
  OPEN_FILE_ERROR = 2
@@ -57,10 +59,18 @@ class CompareException(Exception):
57
59
  INVALID_SUMMARY_MODE = 19
58
60
  INVALID_TASK_ERROR = 20
59
61
  DETACH_ERROR = 21
60
-
62
+ INVALID_OBJECT_TYPE_ERROR = 22
63
+ INVALID_CHAR_ERROR = 23
64
+ RECURSION_LIMIT_ERROR = 24
65
+ INVALID_ATTRIBUTE_ERROR = 25
66
+ OUTPUT_HOOK_ERROR = 26
67
+ INPUT_HOOK_ERROR = 27
68
+ FUNCTION_CALL_ERROR = 28
69
+ FORWARD_DATA_COLLECTION_ERROR = 29
70
+ BACKWARD_DATA_COLLECTION_ERROR = 30
61
71
 
62
72
  def __init__(self, code, error_info: str = ""):
63
- super(CompareException, self).__init__()
73
+ super(MsprobeBaseException, self).__init__()
64
74
  self.code = code
65
75
  self.error_info = error_info
66
76
 
@@ -68,74 +78,33 @@ class CompareException(Exception):
68
78
  return self.error_info
69
79
 
70
80
 
71
- class DumpException(CompareException):
72
- pass
73
-
74
-
75
- def check_mode_valid(mode, scope=None, api_list=None):
76
- if scope is None:
77
- scope = []
78
- if api_list is None:
79
- api_list = []
80
- if not isinstance(scope, list):
81
- raise ValueError("scope param set invalid, it's must be a list.")
82
- if not isinstance(api_list, list):
83
- raise ValueError("api_list param set invalid, it's must be a list.")
84
- mode_check = {
85
- Const.ALL: lambda: None,
86
- Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
87
- Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
88
- Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
89
- Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None,
90
- Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None,
91
- Const.API_STACK: lambda: None,
92
- }
93
- if mode not in Const.DUMP_MODE:
94
- msg = "Current mode '%s' is not supported. Please use the field in %s" % \
95
- (mode, Const.DUMP_MODE)
96
- raise CompareException(CompareException.INVALID_DUMP_MODE, msg)
97
-
98
- if mode_check.get(mode)() is not None:
99
- raise mode_check.get(mode)()
100
-
101
-
102
- def check_switch_valid(switch):
103
- if switch not in ["ON", "OFF"]:
104
- logger.error("Please set switch with 'ON' or 'OFF'.")
105
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
106
-
81
+ class CompareException(MsprobeBaseException):
82
+ """
83
+ Class for Accuracy Compare Exception
84
+ """
107
85
 
108
- def check_dump_mode_valid(dump_mode):
109
- if not isinstance(dump_mode, list):
110
- logger.warning("Please set dump_mode as a list.")
111
- dump_mode = [dump_mode]
112
- if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
113
- raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
114
- if 'input' not in dump_mode and 'output' not in dump_mode:
115
- dump_mode.extend(['input', 'output'])
116
- if 'forward' not in dump_mode and 'backward' not in dump_mode:
117
- dump_mode.extend(['forward', 'backward'])
118
- if 'all' in dump_mode or set(["forward", "backward", "input", "output"]).issubset(set(dump_mode)):
119
- return ["forward", "backward", "input", "output"]
120
- return dump_mode
86
+ def __init__(self, code, error_info: str = ""):
87
+ super(CompareException, self).__init__(code, error_info)
121
88
 
122
89
 
123
- def check_summary_mode_valid(summary_mode):
124
- if summary_mode not in Const.SUMMARY_MODE:
125
- msg = "The summary_mode is not valid"
126
- raise CompareException(CompareException.INVALID_SUMMARY_MODE, msg)
90
+ class DumpException(MsprobeBaseException):
91
+ """
92
+ Class for Dump Exception
93
+ """
127
94
 
95
+ def __init__(self, code, error_info: str = ""):
96
+ super(DumpException, self).__init__(code, error_info)
128
97
 
129
- def check_summary_only_valid(summary_only):
130
- if not isinstance(summary_only, bool):
131
- logger.error("Params summary_only only support True or False.")
132
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
133
- return summary_only
98
+ def __str__(self):
99
+ return f"Dump Error Code {self.code}: {self.error_info}"
134
100
 
135
101
 
136
102
  def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
137
- if not (isinstance(input_param, dict) and isinstance(output_path, str)):
138
- logger.error("Invalid input parameters")
103
+ if not isinstance(input_param, dict):
104
+ logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
105
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
106
+ if not isinstance(output_path, str):
107
+ logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
139
108
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
140
109
 
141
110
  check_file_or_directory_path(input_param.get("npu_json_path"), False)
@@ -152,15 +121,12 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
152
121
  check_json_file(input_param, npu_json, bench_json, stack_json)
153
122
 
154
123
 
155
-
156
- def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
157
- if not (isinstance(stack_mode, bool) and isinstance(auto_analyze, bool) and isinstance(fuzzy_match, bool)):
158
- logger.error("Invalid input parameters which should be only bool type.")
159
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
160
-
161
-
162
- def is_starts_with(string, prefix_list):
163
- return any(string.startswith(prefix) for prefix in prefix_list)
124
+ def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
125
+ arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
126
+ for arg in arg_list:
127
+ if not isinstance(arg, bool):
128
+ logger.error(f"Invalid input parameter, {arg} which should be only bool type.")
129
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
164
130
 
165
131
 
166
132
  def _check_json(json_file_handle, file_name):
@@ -198,28 +164,6 @@ def check_regex_prefix_format_valid(prefix):
198
164
  raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
199
165
 
200
166
 
201
- def get_dump_data_path(dump_dir):
202
- """
203
- Function Description:
204
- traverse directories and obtain the absolute path of dump data
205
- Parameter:
206
- dump_dir: dump data directory
207
- Return Value:
208
- dump data path,file is exist or file is not exist
209
- """
210
- dump_data_path = None
211
- file_is_exist = False
212
-
213
- check_file_or_directory_path(dump_dir, True)
214
- for dir_path, _, files in os.walk(dump_dir):
215
- if len(files) != 0:
216
- dump_data_path = dir_path
217
- file_is_exist = True
218
- break
219
- dump_data_path = dir_path
220
- return dump_data_path, file_is_exist
221
-
222
-
223
167
  def execute_command(cmd):
224
168
  """
225
169
  Function Description:
@@ -241,22 +185,6 @@ def execute_command(cmd):
241
185
  raise CompareException(CompareException.INVALID_DATA_ERROR)
242
186
 
243
187
 
244
- def parse_value_by_comma(value):
245
- """
246
- parse value by comma, like '1,2,4,8'
247
- """
248
- value_list = []
249
- value_str_list = value.split(Const.COMMA)
250
- for value_str in value_str_list:
251
- value_str = value_str.strip()
252
- if value_str.isdigit() or value_str == '-1':
253
- value_list.append(int(value_str))
254
- else:
255
- logger.error("please check your input shape.")
256
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
257
- return value_list
258
-
259
-
260
188
  def add_time_as_suffix(name):
261
189
  return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
262
190
 
@@ -265,6 +193,10 @@ def add_time_with_xlsx(name):
265
193
  return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
266
194
 
267
195
 
196
+ def add_time_with_yaml(name):
197
+ return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
198
+
199
+
268
200
  def get_time():
269
201
  return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
270
202
 
@@ -273,61 +205,6 @@ def format_value(value):
273
205
  return float('{:.12f}'.format(value))
274
206
 
275
207
 
276
- def check_seed_all(seed, mode):
277
- if isinstance(seed, int):
278
- if seed < 0 or seed > Const.MAX_SEED_VALUE:
279
- logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
280
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
281
- else:
282
- logger.error(f"Seed must be integer.")
283
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
284
- if not isinstance(mode, bool):
285
- logger.error(f"seed_all mode must be bool.")
286
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
287
-
288
-
289
- def get_process_rank(model):
290
- logger.info("Rank id is not provided. Trying to get the rank id of the model.")
291
- try:
292
- local_device = next(model.parameters()).device
293
- except StopIteration:
294
- logger.warning('There is no parameter in the model. Fail to get rank id.')
295
- return 0, False
296
- if local_device.type == 'cpu':
297
- logger.warning("Warning: the debugger is unable to get the rank id. "
298
- "This may cause the dumpped data to be corrupted in the "
299
- "case of distributed training. (You may ignore this if you are using only one card.) "
300
- "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
301
- return 0, False
302
- else:
303
- return local_device.index, True
304
-
305
-
306
- def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
307
- template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
308
- pkl_dir = os.path.dirname(pkl_file_path)
309
- compare_script_path = os.path.join(pkl_dir, "compare_data.py")
310
- is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
311
-
312
- try:
313
- with FileOpen(template_path, 'r') as ftemp, \
314
- os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
315
- code_temp = ftemp.read()
316
- fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
317
- except OSError:
318
- logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
319
-
320
- logger.info(f"Generate compare script successfully which is {compare_script_path}.")
321
-
322
-
323
- def check_inplace_op(prefix):
324
- if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
325
- return False
326
- match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
327
- op_name = match_op[0] if match_op else None
328
- return op_name in Const.INPLACE_LIST
329
-
330
-
331
208
  def md5_find(data):
332
209
  for key_op in data:
333
210
  for api_info in data[key_op]:
@@ -340,6 +217,29 @@ def md5_find(data):
340
217
  return False
341
218
 
342
219
 
220
+ def struct_json_get(input_param, framework):
221
+ if framework == Const.PT_FRAMEWORK:
222
+ prefix = "bench"
223
+ elif framework == Const.MS_FRAMEWORK:
224
+ prefix = "npu"
225
+ else:
226
+ logger.error("Error framework found.")
227
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
228
+
229
+ frame_json_path = input_param.get(f"{prefix}_json_path", None)
230
+ if not frame_json_path:
231
+ logger.error(f"Please check the json path is valid.")
232
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
233
+ directory = os.path.dirname(frame_json_path)
234
+ check_file_or_directory_path(directory, True)
235
+ stack_json = os.path.join(directory, "stack.json")
236
+ construct_json = os.path.join(directory, "construct.json")
237
+
238
+ stack = load_json(stack_json)
239
+ construct = load_json(construct_json)
240
+ return stack, construct
241
+
242
+
343
243
  def task_dumppath_get(input_param):
344
244
  npu_path = input_param.get("npu_json_path", None)
345
245
  bench_path = input_param.get("bench_json_path", None)
@@ -383,3 +283,89 @@ def get_header_index(header_name, summary_compare=False):
383
283
 
384
284
  def convert_tuple(data):
385
285
  return data if isinstance(data, tuple) else (data, )
286
+
287
+
288
+ def check_op_str_pattern_valid(string, op_name=None, stack=False):
289
+ if isinstance(string, str) and is_invalid_pattern(string):
290
+ if stack:
291
+ message = f"stack info of {op_name} contains special characters, please check!"
292
+ elif not op_name:
293
+ message = f"api name contains special characters, please check!"
294
+ else:
295
+ message = f"data info of {op_name} contains special characters, please check!"
296
+ logger.error(message)
297
+ raise CompareException(CompareException.INVALID_CHAR_ERROR)
298
+
299
+
300
+ def is_invalid_pattern(string):
301
+ pattern = Const.STRING_BLACKLIST
302
+ return re.search(pattern, string)
303
+
304
+
305
+ def print_tools_ends_info():
306
+ total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
307
+ logger.info('*' * total_len)
308
+ logger.info(f"*{Const.TOOL_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
309
+ logger.info('*' * total_len)
310
+
311
+
312
+ def get_step_or_rank_from_string(step_or_rank, obj):
313
+ splited = step_or_rank.split(Const.HYPHEN)
314
+ if len(splited) == 2:
315
+ try:
316
+ borderlines = int(splited[0]), int(splited[1])
317
+ except (ValueError, IndexError) as e:
318
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
319
+ "The hyphen(-) must start and end with decimal numbers.") from e
320
+ else:
321
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
322
+ f'The string parameter for {obj} only supports formats like "3-5". Now string parameter for {obj} is "{step_or_rank}".')
323
+ if all(Const.STEP_RANK_MAXIMUM_RANGE[0] <= b <= Const.STEP_RANK_MAXIMUM_RANGE[1] for b in borderlines):
324
+ if borderlines[0] <= borderlines[1]:
325
+ continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
326
+ else:
327
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
328
+ f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be greater than the right boundary ({borderlines[1]}).')
329
+ else:
330
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
331
+ f"The boundaries must fall within the range of [{Const.STEP_RANK_MAXIMUM_RANGE[0]}, {Const.STEP_RANK_MAXIMUM_RANGE[1]}].")
332
+ return continual_step_or_rank
333
+
334
+
335
+ def get_real_step_or_rank(step_or_rank_input, obj):
336
+ if obj not in [Const.STEP, Const.RANK]:
337
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
338
+ f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
339
+ if step_or_rank_input is None:
340
+ return []
341
+ if not isinstance(step_or_rank_input, list):
342
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
343
+ real_step_or_rank = []
344
+ for element in step_or_rank_input:
345
+ if not isinstance(element, (int, str)):
346
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
347
+ f"{obj} element {element} must be an integer or string.")
348
+ if isinstance(element, int) and element < 0:
349
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
350
+ f"Each element of {obj} must be non-negative, currently it is {element}.")
351
+ if isinstance(element, int) and Const.STEP_RANK_MAXIMUM_RANGE[0] <= element <= Const.STEP_RANK_MAXIMUM_RANGE[1]:
352
+ real_step_or_rank.append(element)
353
+ elif isinstance(element, str) and Const.HYPHEN in element:
354
+ continual_step_or_rank = get_step_or_rank_from_string(element, obj)
355
+ real_step_or_rank.extend(continual_step_or_rank)
356
+ real_step_or_rank = list(set(real_step_or_rank))
357
+ real_step_or_rank.sort()
358
+ return real_step_or_rank
359
+
360
+
361
+ def check_seed_all(seed, mode):
362
+ if isinstance(seed, int):
363
+ if seed < 0 or seed > Const.MAX_SEED_VALUE:
364
+ logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
365
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
366
+ else:
367
+ logger.error("Seed must be integer.")
368
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
369
+ if not isinstance(mode, bool):
370
+ logger.error("seed_all mode must be bool.")
371
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
@@ -2,18 +2,17 @@ from msprobe.core.common.const import Const, FileCheckConst
2
2
  from msprobe.core.common.log import logger
3
3
  from msprobe.core.common.exceptions import MsprobeException
4
4
  from msprobe.core.common.file_utils import FileChecker
5
+ from msprobe.core.common.utils import get_real_step_or_rank
5
6
 
6
7
 
7
8
  class CommonConfig:
8
9
  def __init__(self, json_config):
9
10
  self.task = json_config.get('task')
10
11
  self.dump_path = json_config.get('dump_path')
11
- self.rank = json_config.get('rank')
12
- self.step = json_config.get('step')
12
+ self.rank = get_real_step_or_rank(json_config.get('rank'), Const.RANK)
13
+ self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
13
14
  self.level = json_config.get('level')
14
- self.seed = json_config.get('seed')
15
15
  self.acl_config = json_config.get('acl_config')
16
- self.is_deterministic = json_config.get('is_deterministic', False)
17
16
  self.enable_dataloader = json_config.get('enable_dataloader', False)
18
17
  self._check_config()
19
18
 
@@ -24,21 +23,9 @@ class CommonConfig:
24
23
  if self.dump_path is not None and not isinstance(self.dump_path, str):
25
24
  logger.error_log_with_exp("dump_path is invalid, it should be a string",
26
25
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
27
- if self.rank is not None and not isinstance(self.rank, list):
28
- logger.error_log_with_exp("rank is invalid, it should be a list",
29
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
30
- if self.step is not None and not isinstance(self.step, list):
31
- logger.error_log_with_exp("step is invalid, it should be a list",
32
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
33
26
  if self.level and self.level not in Const.LEVEL_LIST:
34
27
  logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST),
35
28
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
36
- if self.seed is not None and not isinstance(self.seed, int):
37
- logger.error_log_with_exp("seed is invalid, it should be an integer",
38
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
39
- if not isinstance(self.is_deterministic, bool):
40
- logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean",
41
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
42
29
  if not isinstance(self.enable_dataloader, bool):
43
30
  logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
44
31
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
@@ -73,13 +60,19 @@ class BaseConfig:
73
60
  self.preheat_step = json_config.get("preheat_step")
74
61
  self.max_sample = json_config.get("max_sample")
75
62
 
63
+ @staticmethod
64
+ def _check_str_list_config(config_item, config_name):
65
+ if config_item is not None:
66
+ if not isinstance(config_item, list):
67
+ logger.error_log_with_exp(f"{config_name} is invalid, it should be a list[str]",
68
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
69
+ for name in config_item:
70
+ if not isinstance(name, str):
71
+ logger.error_log_with_exp(f"{config_name} is invalid, it should be a list[str]",
72
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
73
+
76
74
  def check_config(self):
77
- if self.scope is not None and not isinstance(self.scope, list):
78
- logger.error_log_with_exp("scope is invalid, it should be a list",
79
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
80
- if self.list is not None and not isinstance(self.list, list):
81
- logger.error_log_with_exp("list is invalid, it should be a list",
82
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
83
- if self.data_mode is not None and not isinstance(self.data_mode, list):
84
- logger.error_log_with_exp("data_mode is invalid, it should be a list",
85
- MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
75
+ self._check_str_list_config(self.scope, "scope")
76
+ self._check_str_list_config(self.list, "list")
77
+ self._check_str_list_config(self.data_mode, "data_mode")
78
+ self._check_str_list_config(self.backward_input, "backward_input")