mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,28 +12,32 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import collections
18
17
  import os
19
18
  import re
20
19
  import subprocess
21
20
  import time
22
- import json
21
+ from collections import defaultdict
23
22
  from datetime import datetime, timezone
23
+ from functools import wraps
24
24
 
25
- from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
25
+ import numpy as np
26
+
27
+ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
26
28
  from msprobe.core.common.const import Const, CompareConst
27
29
  from msprobe.core.common.log import logger
28
-
30
+ from msprobe.core.common.exceptions import MsprobeException
29
31
 
30
32
  device = collections.namedtuple('device', ['type', 'index'])
31
33
  prefixes = ['api_stack', 'list', 'range', 'acl']
32
34
 
33
35
 
34
- class CompareException(Exception):
36
+ class MsprobeBaseException(Exception):
35
37
  """
36
- Class for Accuracy Compare Exception
38
+ Base class for all custom exceptions.
37
39
  """
40
+ # 所有的错误代码
38
41
  NONE_ERROR = 0
39
42
  INVALID_PATH_ERROR = 1
40
43
  OPEN_FILE_ERROR = 2
@@ -57,10 +60,20 @@ class CompareException(Exception):
57
60
  INVALID_SUMMARY_MODE = 19
58
61
  INVALID_TASK_ERROR = 20
59
62
  DETACH_ERROR = 21
60
-
63
+ INVALID_OBJECT_TYPE_ERROR = 22
64
+ INVALID_CHAR_ERROR = 23
65
+ RECURSION_LIMIT_ERROR = 24
66
+ INVALID_ATTRIBUTE_ERROR = 25
67
+ OUTPUT_HOOK_ERROR = 26
68
+ INPUT_HOOK_ERROR = 27
69
+ FUNCTION_CALL_ERROR = 28
70
+ FORWARD_DATA_COLLECTION_ERROR = 29
71
+ BACKWARD_DATA_COLLECTION_ERROR = 30
72
+ INVALID_KEY_ERROR = 31
73
+ MISSING_HEADER_ERROR = 32
61
74
 
62
75
  def __init__(self, code, error_info: str = ""):
63
- super(CompareException, self).__init__()
76
+ super(MsprobeBaseException, self).__init__()
64
77
  self.code = code
65
78
  self.error_info = error_info
66
79
 
@@ -68,80 +81,55 @@ class CompareException(Exception):
68
81
  return self.error_info
69
82
 
70
83
 
71
- class DumpException(CompareException):
72
- pass
73
-
74
-
75
- def check_mode_valid(mode, scope=None, api_list=None):
76
- if scope is None:
77
- scope = []
78
- if api_list is None:
79
- api_list = []
80
- if not isinstance(scope, list):
81
- raise ValueError("scope param set invalid, it's must be a list.")
82
- if not isinstance(api_list, list):
83
- raise ValueError("api_list param set invalid, it's must be a list.")
84
- mode_check = {
85
- Const.ALL: lambda: None,
86
- Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
87
- Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
88
- Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
89
- Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None,
90
- Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None,
91
- Const.API_STACK: lambda: None,
92
- }
93
- if mode not in Const.DUMP_MODE:
94
- msg = "Current mode '%s' is not supported. Please use the field in %s" % \
95
- (mode, Const.DUMP_MODE)
96
- raise CompareException(CompareException.INVALID_DUMP_MODE, msg)
97
-
98
- if mode_check.get(mode)() is not None:
99
- raise mode_check.get(mode)()
100
-
101
-
102
- def check_switch_valid(switch):
103
- if switch not in ["ON", "OFF"]:
104
- logger.error("Please set switch with 'ON' or 'OFF'.")
105
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
84
+ class CompareException(MsprobeBaseException):
85
+ """
86
+ Class for Accuracy Compare Exception
87
+ """
106
88
 
89
+ def __init__(self, code, error_info: str = ""):
90
+ super(CompareException, self).__init__(code, error_info)
107
91
 
108
- def check_dump_mode_valid(dump_mode):
109
- if not isinstance(dump_mode, list):
110
- logger.warning("Please set dump_mode as a list.")
111
- dump_mode = [dump_mode]
112
- if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
113
- raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
114
- if 'input' not in dump_mode and 'output' not in dump_mode:
115
- dump_mode.extend(['input', 'output'])
116
- if 'forward' not in dump_mode and 'backward' not in dump_mode:
117
- dump_mode.extend(['forward', 'backward'])
118
- if 'all' in dump_mode or set(["forward", "backward", "input", "output"]).issubset(set(dump_mode)):
119
- return ["forward", "backward", "input", "output"]
120
- return dump_mode
121
92
 
93
+ class DumpException(MsprobeBaseException):
94
+ """
95
+ Class for Dump Exception
96
+ """
122
97
 
123
- def check_summary_mode_valid(summary_mode):
124
- if summary_mode not in Const.SUMMARY_MODE:
125
- msg = "The summary_mode is not valid"
126
- raise CompareException(CompareException.INVALID_SUMMARY_MODE, msg)
98
+ def __init__(self, code, error_info: str = ""):
99
+ super(DumpException, self).__init__(code, error_info)
127
100
 
101
+ def __str__(self):
102
+ return f"Dump Error Code {self.code}: {self.error_info}"
128
103
 
129
- def check_summary_only_valid(summary_only):
130
- if not isinstance(summary_only, bool):
131
- logger.error("Params summary_only only support True or False.")
132
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
133
- return summary_only
104
+
105
+ def is_json_file(file_path):
106
+ if isinstance(file_path, str) and file_path.lower().endswith('.json'):
107
+ return True
108
+ else:
109
+ return False
134
110
 
135
111
 
136
- def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
137
- if not (isinstance(input_param, dict) and isinstance(output_path, str)):
138
- logger.error("Invalid input parameters")
112
+ def check_compare_param(input_param, output_path, dump_mode):
113
+ if not isinstance(input_param, dict):
114
+ logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
139
115
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
116
+ if not isinstance(output_path, str):
117
+ logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
118
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
119
+
120
+ def check_json_path(json_path_str):
121
+ json_path = input_param.get(json_path_str)
122
+ check_file_or_directory_path(json_path, False)
123
+ json_type_check = is_json_file(json_path)
124
+ if not json_type_check:
125
+ logger.error(f"Invalid {json_path_str}: {json_path}, please check!")
126
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
127
+
128
+ check_json_path("npu_json_path")
129
+ check_json_path("bench_json_path")
130
+ check_json_path("stack_json_path")
140
131
 
141
- check_file_or_directory_path(input_param.get("npu_json_path"), False)
142
- check_file_or_directory_path(input_param.get("bench_json_path"), False)
143
- check_file_or_directory_path(input_param.get("stack_json_path"), False)
144
- if not summary_compare and not md5_compare:
132
+ if dump_mode == Const.ALL:
145
133
  check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
146
134
  check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
147
135
  check_file_or_directory_path(output_path, True)
@@ -152,15 +140,12 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
152
140
  check_json_file(input_param, npu_json, bench_json, stack_json)
153
141
 
154
142
 
155
-
156
- def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
157
- if not (isinstance(stack_mode, bool) and isinstance(auto_analyze, bool) and isinstance(fuzzy_match, bool)):
158
- logger.error("Invalid input parameters which should be only bool type.")
159
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
160
-
161
-
162
- def is_starts_with(string, prefix_list):
163
- return any(string.startswith(prefix) for prefix in prefix_list)
143
+ def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
144
+ arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
145
+ for arg in arg_list:
146
+ if not isinstance(arg, bool):
147
+ logger.error(f"Invalid input parameter, {arg} which should be only bool type.")
148
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
164
149
 
165
150
 
166
151
  def _check_json(json_file_handle, file_name):
@@ -198,28 +183,6 @@ def check_regex_prefix_format_valid(prefix):
198
183
  raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
199
184
 
200
185
 
201
- def get_dump_data_path(dump_dir):
202
- """
203
- Function Description:
204
- traverse directories and obtain the absolute path of dump data
205
- Parameter:
206
- dump_dir: dump data directory
207
- Return Value:
208
- dump data path,file is exist or file is not exist
209
- """
210
- dump_data_path = None
211
- file_is_exist = False
212
-
213
- check_file_or_directory_path(dump_dir, True)
214
- for dir_path, _, files in os.walk(dump_dir):
215
- if len(files) != 0:
216
- dump_data_path = dir_path
217
- file_is_exist = True
218
- break
219
- dump_data_path = dir_path
220
- return dump_data_path, file_is_exist
221
-
222
-
223
186
  def execute_command(cmd):
224
187
  """
225
188
  Function Description:
@@ -235,28 +198,12 @@ def execute_command(cmd):
235
198
  line = process.stdout.readline()
236
199
  line = line.strip()
237
200
  if line:
238
- print(line)
201
+ logger.info(line)
239
202
  if process.returncode != 0:
240
203
  logger.error('Failed to execute command:%s' % " ".join(cmd))
241
204
  raise CompareException(CompareException.INVALID_DATA_ERROR)
242
205
 
243
206
 
244
- def parse_value_by_comma(value):
245
- """
246
- parse value by comma, like '1,2,4,8'
247
- """
248
- value_list = []
249
- value_str_list = value.split(Const.COMMA)
250
- for value_str in value_str_list:
251
- value_str = value_str.strip()
252
- if value_str.isdigit() or value_str == '-1':
253
- value_list.append(int(value_str))
254
- else:
255
- logger.error("please check your input shape.")
256
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
257
- return value_list
258
-
259
-
260
207
  def add_time_as_suffix(name):
261
208
  return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
262
209
 
@@ -265,6 +212,10 @@ def add_time_with_xlsx(name):
265
212
  return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
266
213
 
267
214
 
215
+ def add_time_with_yaml(name):
216
+ return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
217
+
218
+
268
219
  def get_time():
269
220
  return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
270
221
 
@@ -273,61 +224,6 @@ def format_value(value):
273
224
  return float('{:.12f}'.format(value))
274
225
 
275
226
 
276
- def check_seed_all(seed, mode):
277
- if isinstance(seed, int):
278
- if seed < 0 or seed > Const.MAX_SEED_VALUE:
279
- logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
280
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
281
- else:
282
- logger.error(f"Seed must be integer.")
283
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
284
- if not isinstance(mode, bool):
285
- logger.error(f"seed_all mode must be bool.")
286
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
287
-
288
-
289
- def get_process_rank(model):
290
- logger.info("Rank id is not provided. Trying to get the rank id of the model.")
291
- try:
292
- local_device = next(model.parameters()).device
293
- except StopIteration:
294
- logger.warning('There is no parameter in the model. Fail to get rank id.')
295
- return 0, False
296
- if local_device.type == 'cpu':
297
- logger.warning("Warning: the debugger is unable to get the rank id. "
298
- "This may cause the dumpped data to be corrupted in the "
299
- "case of distributed training. (You may ignore this if you are using only one card.) "
300
- "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
301
- return 0, False
302
- else:
303
- return local_device.index, True
304
-
305
-
306
- def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
307
- template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
308
- pkl_dir = os.path.dirname(pkl_file_path)
309
- compare_script_path = os.path.join(pkl_dir, "compare_data.py")
310
- is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
311
-
312
- try:
313
- with FileOpen(template_path, 'r') as ftemp, \
314
- os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
315
- code_temp = ftemp.read()
316
- fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
317
- except OSError:
318
- logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
319
-
320
- logger.info(f"Generate compare script successfully which is {compare_script_path}.")
321
-
322
-
323
- def check_inplace_op(prefix):
324
- if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
325
- return False
326
- match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
327
- op_name = match_op[0] if match_op else None
328
- return op_name in Const.INPLACE_LIST
329
-
330
-
331
227
  def md5_find(data):
332
228
  for key_op in data:
333
229
  for api_info in data[key_op]:
@@ -335,46 +231,89 @@ def md5_find(data):
335
231
  for data_detail in data[key_op][api_info]:
336
232
  if data_detail and 'md5' in data_detail:
337
233
  return True
338
- elif 'md5' in data[key_op][api_info]:
234
+ elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
339
235
  return True
340
236
  return False
341
237
 
342
238
 
343
- def task_dumppath_get(input_param):
239
def detect_framework_by_dump_json(file_path):
    """Scan a dump json line by line and report which framework produced it.

    The first line containing ``"type": "mindspore...`` yields
    ``Const.MS_FRAMEWORK``; the first line containing ``"type": "torch...``
    yields ``Const.PT_FRAMEWORK``. On a single line the MindSpore marker is
    tested first.

    Args:
        file_path: path of the dump json, opened through ``FileOpen``.

    Returns:
        ``Const.MS_FRAMEWORK`` or ``Const.PT_FRAMEWORK``.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` when no line carries either marker.
    """
    mindspore_marker = re.compile(r'"type":\s*"mindspore')
    torch_marker = re.compile(r'"type":\s*"torch')
    with FileOpen(file_path, 'r') as dump_file:
        for text in dump_file:
            if mindspore_marker.search(text):
                return Const.MS_FRAMEWORK
            if torch_marker.search(text):
                return Const.PT_FRAMEWORK
    logger.error(f"{file_path} must be based on the MindSpore or PyTorch framework.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
250
+
251
+
252
def get_stack_construct_by_dump_json_path(dump_json_path):
    """Load ``stack.json`` and ``construct.json`` from the directory of a dump json.

    Args:
        dump_json_path: path of the dump json; only its directory part is used.

    Returns:
        tuple: ``(stack, construct)`` as returned by ``load_json`` for the two files.

    Raises:
        CompareException: ``INVALID_PATH_ERROR`` when ``dump_json_path`` is empty.
            ``check_file_or_directory_path`` and ``load_json`` may raise on their own.
    """
    if not dump_json_path:
        logger.error("The path is empty. Please enter a valid path.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    dump_dir = os.path.dirname(dump_json_path)
    check_file_or_directory_path(dump_dir, True)

    # stack.json is loaded before construct.json.
    stack = load_json(os.path.join(dump_dir, "stack.json"))
    construct = load_json(os.path.join(dump_dir, "construct.json"))
    return stack, construct
264
+
265
+
266
def set_dump_path(input_param):
    """Validate both dump.json paths and record their tensor-data directories.

    Args:
        input_param (dict): read for "npu_json_path" and "bench_json_path".
            On success the keys 'npu_dump_data_dir' and 'bench_dump_data_dir'
            are written, each being the directory of the corresponding json
            joined with ``Const.DUMP_TENSOR_DATA``.

    Raises:
        CompareException: ``INVALID_PATH_ERROR`` when either path is missing,
            is not a string, or does not end with "dump.json".
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)
    # isinstance (instead of an "is not None" test) makes a non-string value
    # fail validation here rather than crash with AttributeError on .endswith().
    paths_valid = all(
        isinstance(path, str) and path.endswith("dump.json")
        for path in (npu_path, bench_path)
    )
    if not paths_valid:
        logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
276
+
277
+
278
def get_dump_mode(input_param):
    """Derive the comparison mode from the 'task' field of the NPU and bench dumps.

    Args:
        input_param (dict): read for "npu_json_path" and "bench_json_path";
            both files are loaded with ``load_json``.

    Returns:
        ``Const.ALL`` when the shared task equals ``Const.TENSOR``. When it
        equals ``Const.STATISTICS``: ``Const.MD5`` if ``md5_find`` is truthy for
        both dumps, ``Const.SUMMARY`` if it is falsy for both.

    Raises:
        CompareException: ``INVALID_TASK_ERROR`` when a task is missing/empty,
            the two tasks differ, the two ``md5_find`` results differ, or the
            task is neither tensor nor statistics.
    """
    npu_dump = load_json(input_param.get("npu_json_path", None))
    bench_dump = load_json(input_param.get("bench_json_path", None))

    npu_task = npu_dump.get('task', None)
    bench_task = bench_dump.get('task', None)

    if not (npu_task and bench_task):
        logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if npu_task != bench_task:
        logger.error(f"Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if npu_task == Const.TENSOR:
        return Const.ALL

    if npu_task != Const.STATISTICS:
        logger.error(f"Compare applies only to task is tensor or statistics")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    # Statistics dumps may or may not carry md5 values; both sides must agree.
    npu_has_md5 = md5_find(npu_dump['data'])
    bench_has_md5 = md5_find(bench_dump['data'])
    if npu_has_md5 != bench_has_md5:
        logger.error(f"Please check the dump task is consistent, "
                     f"dump mode of npu and bench should both be statistics or md5.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)
    return Const.MD5 if npu_has_md5 else Const.SUMMARY
372
310
 
373
- def get_header_index(header_name, summary_compare=False):
374
- if summary_compare:
375
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
376
- else:
377
- header = CompareConst.COMPARE_RESULT_HEADER[:]
311
+
312
+ def get_header_index(header_name, dump_mode):
313
+ header = CompareConst.HEAD_OF_COMPARE_MODE.get(dump_mode)
314
+ if not header:
315
+ logger.error(f"{dump_mode} not in {CompareConst.HEAD_OF_COMPARE_MODE}")
316
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
378
317
  if header_name not in header:
379
318
  logger.error(f"{header_name} not in data name")
380
319
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -382,4 +321,164 @@ def get_header_index(header_name, summary_compare=False):
382
321
 
383
322
 
384
323
def convert_tuple(data):
    """Return ``data`` itself when it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
325
+
326
+
327
def check_op_str_pattern_valid(string, op_name=None, stack=False):
    """Reject a string value that ``is_invalid_pattern`` flags.

    Non-string values, and strings that are not flagged, pass silently.

    Args:
        string: value to check; only ``str`` instances are inspected.
        op_name: name used in the error message; when falsy (and ``stack`` is
            falsy) the message refers to the api name itself.
        stack: when truthy the message refers to the stack info of ``op_name``.

    Raises:
        CompareException: ``INVALID_CHAR_ERROR`` when the string is flagged.
    """
    if not isinstance(string, str) or not is_invalid_pattern(string):
        return

    if stack:
        message = f"stack info of {op_name} contains special characters, please check!"
    elif op_name:
        message = f"data info of {op_name} contains special characters, please check!"
    else:
        message = f"api name contains special characters, please check!"
    logger.error(message)
    raise CompareException(CompareException.INVALID_CHAR_ERROR)
337
+
338
+
339
def is_invalid_pattern(string):
    """Search ``string`` for the ``Const.STRING_BLACKLIST`` regex.

    Returns:
        The ``re.Match`` object of the first hit (truthy), or ``None`` when the
        pattern does not occur.
    """
    return re.search(Const.STRING_BLACKLIST, string)
342
+
343
+
344
def is_int(x):
    """Return True for ``int`` instances that are not ``bool``.

    ``bool`` is a subclass of ``int``, so it has to be excluded explicitly.
    """
    if isinstance(x, bool):
        return False
    return isinstance(x, int)
346
+
347
+
348
def print_tools_ends_info():
    """Log ``Const.TOOL_ENDS_SUCCESSFULLY`` centred inside a three-line box of asterisks.

    The box width is the message length plus ``Const.FILL_CHAR_NUMS``.
    """
    banner_text = Const.TOOL_ENDS_SUCCESSFULLY
    width = len(banner_text) + Const.FILL_CHAR_NUMS
    border = '*' * width
    # The middle row reserves one asterisk at each edge, hence width - 2.
    for row in (border, f"*{banner_text.center(width - 2)}*", border):
        logger.info(row)
353
+
354
+
355
def get_step_or_rank_from_string(step_or_rank, obj):
    """Expand a "left-right" string into the inclusive list of integers it spans.

    Args:
        step_or_rank (str): text split on ``Const.HYPHEN``; it must produce
            exactly two parts that ``int()`` can parse.
        obj: label (step or rank) used only in error messages.

    Returns:
        list[int]: ``left, left + 1, ..., right``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when the string does not split
            into two parts, a part is not an integer, a boundary lies outside
            ``[Const.STEP_RANK_MINIMUM_VALUE, Const.STEP_RANK_MAXIMUM_VALUE]``,
            or the left boundary exceeds the right one.
    """
    parts = step_or_rank.split(Const.HYPHEN)
    if len(parts) != 2:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f'The string parameter for {obj} only supports formats like "3-5". '
                               f'Now string parameter for {obj} is "{step_or_rank}".')
    try:
        left, right = int(parts[0]), int(parts[1])
    except (ValueError, IndexError) as e:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "The hyphen(-) must start and end with decimal numbers.") from e

    within_limits = all(
        Const.STEP_RANK_MINIMUM_VALUE <= bound <= Const.STEP_RANK_MAXIMUM_VALUE
        for bound in (left, right)
    )
    if not within_limits:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"The boundaries must fall within the range of "
                               f"[{Const.STEP_RANK_MINIMUM_VALUE}, {Const.STEP_RANK_MAXIMUM_VALUE}].")
    if left > right:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f'For the hyphen(-) in {obj}, the left boundary ({left}) cannot be '
                               f'greater than the right boundary ({right}).')
    return list(range(left, right + 1))
379
+
380
+
381
def get_real_step_or_rank(step_or_rank_input, obj):
    """Normalise a user-supplied step/rank list into sorted, de-duplicated integers.

    Integers within ``[Const.STEP_RANK_MINIMUM_VALUE, Const.STEP_RANK_MAXIMUM_VALUE]``
    are kept as they are; strings containing ``Const.HYPHEN`` are expanded via
    ``get_step_or_rank_from_string``. Integers above the maximum and strings
    without a hyphen are silently ignored.

    Args:
        step_or_rank_input (list | None): raw configuration value.
        obj: must be ``Const.STEP`` or ``Const.RANK``; also used in messages.

    Returns:
        list[int]: ascending, unique values; ``[]`` when the input is ``None``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` for an unsupported ``obj``, a
            non-list input, a list longer than ``Const.STEP_RANK_MAXIMUM_VALUE``,
            an element that is neither a non-bool int nor a str, or a negative int.
    """
    if obj not in [Const.STEP, Const.RANK]:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
    if step_or_rank_input is None:
        return []
    if not isinstance(step_or_rank_input, list):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
    if len(step_or_rank_input) > Const.STEP_RANK_MAXIMUM_VALUE:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"{obj} is invalid, its length cannot exceed {Const.STEP_RANK_MAXIMUM_VALUE}")

    collected = set()
    for item in step_or_rank_input:
        if isinstance(item, str):
            # Only hyphenated ranges are expanded; other strings are dropped.
            if Const.HYPHEN in item:
                collected.update(get_step_or_rank_from_string(item, obj))
            continue
        if not is_int(item):
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f"{obj} element {item} must be an integer or string.")
        if item < 0:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f"Each element of {obj} must be non-negative, currently it is {item}.")
        if Const.STEP_RANK_MINIMUM_VALUE <= item <= Const.STEP_RANK_MAXIMUM_VALUE:
            collected.add(item)
    return sorted(collected)
409
+
410
+
411
def check_seed_all(seed, mode):
    """Validate the arguments of seed_all.

    Args:
        seed: must be a non-bool int within ``[0, Const.MAX_SEED_VALUE]``.
        mode: must be a ``bool``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when a check fails; the
            reason is logged first.
    """
    if not is_int(seed):
        logger.error("Seed must be integer.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
422
+
423
+
424
def safe_get_value(container, index, container_name, key=None):
    """Index into a container, turning index/type failures into project exceptions.

    Args:
        container: a dict, list, tuple or ``numpy.ndarray``.
        index: index applied to the container (or, for a dict, to the value
            stored under ``key``).
        container_name (str): name used in log messages.
        key: dict key to look up first; only used when ``container`` is a dict.

    Returns:
        ``container.get(key)[index]`` for a dict, ``container[index]`` otherwise.

    Raises:
        MsprobeBaseException: ``INVALID_OBJECT_TYPE_ERROR`` for an unsupported
            container type or a ``TypeError`` during the lookup;
            ``INDEX_OUT_OF_BOUNDS_ERROR`` for an ``IndexError``.
    """
    try:
        if isinstance(container, dict):
            # Dict case: index into the value stored under ``key``.
            target = container.get(key)
        elif isinstance(container, (list, tuple, np.ndarray)):
            # List, tuple and numpy array case: index the container directly.
            target = container
        else:
            err_msg = f"Unsupported container type for '{container_name}': {type(container)}"
            logger.error(err_msg)
            raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR)
        return target[index]
    except IndexError as e:
        err_msg = ("index out of bounds error occurs, please check!\n"
                   f"{container_name} is {container}\n"
                   f"index is {index}")
        logger.error(err_msg)
        raise MsprobeBaseException(MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    except TypeError as e:
        err_msg = ("wrong type, please check!\n"
                   f"{container_name} is {container}\n"
                   f"index is {index}\n"
                   f"key is {key}")
        logger.error(err_msg)
        raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
449
+
450
+
451
# Records the current recursion depth of each decorated utility function,
# keyed by id(func).
recursion_depth = defaultdict(int)


# Decorates a function so that, when its recursive calls exceed the limit,
# an exception is reported together with the function info.
def recursion_depth_decorator(func_info):
    """Build a decorator that limits the recursion depth of the wrapped function.

    Every entry into the wrapped function increments its counter in
    ``recursion_depth`` and every exit decrements it. When the counter exceeds
    ``Const.MAX_DEPTH`` the condition is handed to
    ``logger.error_log_with_exp`` together with a ``MsprobeException``
    (``RECURSION_LIMIT_ERROR``); that helper is expected to raise the
    exception -- confirm against the logger implementation.

    Args:
        func_info (str): description of the function, used in the message.

    Returns:
        A decorator that preserves the wrapped function's metadata via ``wraps``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            # The limit check lives inside the try block so the counter is
            # decremented even when the over-limit call raises; otherwise each
            # overflow would leave the counter one too high for later calls.
            try:
                if recursion_depth[func_id] > Const.MAX_DEPTH:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
479
+
480
+
481
def check_str_param(param):
    """Ensure ``param`` matches ``Const.REGEX_PREFIX_PATTERN`` from its start.

    Args:
        param (str): value to validate with ``re.match``.

    Raises:
        MsprobeBaseException: ``INVALID_CHAR_ERROR`` when the pattern does not match.
    """
    if re.match(Const.REGEX_PREFIX_PATTERN, param):
        return
    logger.error('The parameter {} contains special characters.'.format(param))
    raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)