mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import multiprocessing
17
+ from multiprocessing.shared_memory import SharedMemory
18
+ import random
19
+ import time
20
+ import atexit
21
+ import os
22
+
23
+ from msprobe.core.common.log import logger
24
+
25
+
26
def is_main_process():
    """Return True when the current multiprocessing process is named 'MainProcess'."""
    current = multiprocessing.current_process()
    return current.name == 'MainProcess'
28
+
29
+
30
class GlobalLock:
    """
    Lock flag stored in the first byte of a one-byte named shared-memory segment.

    A byte value of 0 means the lock is free and 1 means it is held. The segment
    name is derived from a process id (see ``get_lock_name``).

    NOTE(review): ``acquire`` reads and writes the flag in separate statements, so
    the check-and-set is not a single atomic operation; the random delays look
    like they are meant to make simultaneous acquisition less likely - confirm
    this is acceptable for the callers.
    """

    def __init__(self):
        """
        Attach to the shared-memory segment for this lock, creating it if missing.

        If the segment already exists, attach to it and sleep a random 0-50 ms.
        Otherwise create a one-byte segment and initialise the flag to 0. If the
        creation fails because the segment appeared in the meantime, initialisation
        is retried from the start.
        """
        self.name = self.get_lock_name()
        try:
            self._shm = SharedMemory(create=False, name=self.name)
            time.sleep(random.randint(0, 500) / 10000)  # wait a random duration to avoid acquiring the lock at the same time
        except FileNotFoundError:
            try:
                self._shm = SharedMemory(create=True, name=self.name, size=1)
                self._shm.buf[0] = 0
                logger.debug(f'{self.name} is created.')
            except FileExistsError:
                # The segment was created between the failed attach and this create; start over.
                self.__init__()

    @classmethod
    def get_lock_name(cls):
        """
        Build the shared-memory segment name.

        The main process uses its own pid; any other process uses its parent's pid.
        Presumably this lets a main process and its direct children resolve the
        same name - verify for deeper process trees.

        :return: the segment name, ``global_lock_<pid>``.
        """
        if is_main_process():
            return f'global_lock_{os.getpid()}'
        return f'global_lock_{os.getppid()}'

    @classmethod
    def is_lock_exist(cls):
        """
        Check whether the shared-memory segment for this lock already exists.

        :return: True if the segment could be attached to (the handle is closed
            again immediately), False if it was not found.
        """
        try:
            SharedMemory(create=False, name=cls.get_lock_name()).close()
            return True
        except FileNotFoundError:
            return False

    def cleanup(self):
        """
        Close this process's handle to the segment; in the main process also unlink it.

        A segment that is already unlinked only produces a warning log.
        """
        self._shm.close()
        if is_main_process():
            try:
                self._shm.unlink()
                logger.debug(f'{self.name} is unlinked.')
            except FileNotFoundError:
                logger.warning(f'{self.name} has already been unlinked.')

    def acquire(self, timeout=180):
        """
        Acquire the global lock, the default timeout is 3 minutes.

        Polls the flag byte: when it is 0 it is set to 1 and the call returns.
        Between polls the call sleeps a random 1-50 ms. If the timeout elapses
        without seeing a free flag, the flag is set to 1 anyway and the call
        returns, i.e. the lock is taken even though it was never observed free.

        :param float timeout: timeout(seconds), default value is 180.
        """
        start = time.time()
        while time.time() - start < timeout:
            if self._shm.buf[0] == 0:
                self._shm.buf[0] = 1
                return
            time.sleep(random.randint(10, 500) / 10000)  # spin, waiting 1-50 ms
        # Timed out: write the held state regardless of the current flag value.
        self._shm.buf[0] = 1

    def release(self):
        """Mark the lock as free by writing 0 to the flag byte."""
        self._shm.buf[0] = 0
83
+
84
+
85
# Module-level lock instance, created when this module is imported; its cleanup()
# is registered to run at interpreter exit.
global_lock = GlobalLock()
atexit.register(global_lock.cleanup)
@@ -0,0 +1,25 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.common.const import Const
17
+
18
+
19
class Runtime:
    """Holder of class-level runtime state fields and their default values."""

    step_count: int = 0
    rank_id: int = -1
    is_running: bool = False
    run_mode: str = Const.PYNATIVE_MODE
    current_iter: int = 0
    # A bare annotation (``current_rank: None``) does not create the attribute, so
    # reading Runtime.current_rank before any assignment raised AttributeError.
    # Give it an explicit None default instead.
    current_rank = None
@@ -18,9 +18,8 @@ import os
18
18
  import re
19
19
  import subprocess
20
20
  import time
21
- from collections import defaultdict
21
+ import inspect
22
22
  from datetime import datetime, timezone
23
- from functools import wraps
24
23
 
25
24
  import numpy as np
26
25
 
@@ -28,10 +27,15 @@ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_pa
28
27
  from msprobe.core.common.const import Const, CompareConst
29
28
  from msprobe.core.common.log import logger
30
29
  from msprobe.core.common.exceptions import MsprobeException
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
31
31
 
32
32
 
33
33
  device = collections.namedtuple('device', ['type', 'index'])
34
34
  prefixes = ['api_stack', 'list', 'range', 'acl']
35
# Lookup table from a dump file name to its file-type constant.
file_suffix_to_file_type = {
    "dump.json": Const.DUMP_JSON_FILE,
    "debug.json": Const.DEBUG_JSON_FILE,
}
35
39
 
36
40
 
37
41
  class MsprobeBaseException(Exception):
@@ -75,6 +79,8 @@ class MsprobeBaseException(Exception):
75
79
  MERGE_COMPARE_RESULT_ERROR = 33
76
80
  NAMES_STRUCTS_MATCH_ERROR = 34
77
81
  INVALID_STATE_ERROR = 35
82
+ INVALID_API_NAME_ERROR = 36
83
+ CROSS_FRAME_ERROR = 37
78
84
 
79
85
  def __init__(self, code, error_info: str = ""):
80
86
  super(MsprobeBaseException, self).__init__()
@@ -191,27 +197,6 @@ def check_regex_prefix_format_valid(prefix):
191
197
  raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
192
198
 
193
199
 
194
- def execute_command(cmd):
195
- """
196
- Function Description:
197
- run the following command
198
- Parameter:
199
- cmd: command
200
- Exception Description:
201
- when invalid command throw exception
202
- """
203
- logger.info('Execute command:%s' % cmd)
204
- process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
205
- while process.poll() is None:
206
- line = process.stdout.readline()
207
- line = line.strip()
208
- if line:
209
- logger.info(line)
210
- if process.returncode != 0:
211
- logger.error('Failed to execute command:%s' % " ".join(cmd))
212
- raise CompareException(CompareException.INVALID_DATA_ERROR)
213
-
214
-
215
200
  def add_time_as_suffix(name):
216
201
  return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
217
202
 
@@ -232,21 +217,41 @@ def format_value(value):
232
217
  return float('{:.12f}'.format(value))
233
218
 
234
219
 
235
- def md5_find(data):
236
- for key_op in data:
237
- for api_info in data[key_op]:
238
- if isinstance(data[key_op][api_info], list):
239
- for data_detail in data[key_op][api_info]:
240
- if data_detail and 'md5' in data_detail:
241
- return True
242
- if isinstance(data[key_op][api_info], bool):
243
- continue
244
- elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
220
+ @recursion_depth_decorator('msprobe.core.common.utils.md5_find', max_depth=Const.DUMP_MAX_DEPTH)
221
+ def md5_find(data, json_type=Const.DUMP_JSON_FILE):
222
+ if json_type == Const.DUMP_JSON_FILE:
223
+ for key_op in data:
224
+ for api_info in data[key_op]:
225
+ if isinstance(data[key_op][api_info], list):
226
+ for data_detail in data[key_op][api_info]:
227
+ if data_detail and Const.MD5 in data_detail:
228
+ return True
229
+ if isinstance(data[key_op][api_info], bool):
230
+ continue
231
+ elif data[key_op][api_info] and Const.MD5 in data[key_op][api_info]:
232
+ return True
233
+ elif json_type == Const.DEBUG_JSON_FILE:
234
+ if isinstance(data, dict):
235
+ if Const.MD5 in data:
245
236
  return True
237
+ else:
238
+ for _, data_info in data.items():
239
+ if md5_find(data_info, Const.DEBUG_JSON_FILE):
240
+ return True
241
+ elif isinstance(data, list):
242
+ for data_info in data:
243
+ if md5_find(data_info, Const.DEBUG_JSON_FILE):
244
+ return True
245
+ else:
246
+ return False
246
247
  return False
247
248
 
248
249
 
249
250
  def detect_framework_by_dump_json(file_path):
251
+ json_data = load_json(file_path)
252
+ framework = json_data.get("framework", None)
253
+ if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]:
254
+ return framework
250
255
  pattern_ms = r'"type":\s*"mindspore'
251
256
  pattern_pt = r'"type":\s*"torch'
252
257
  with FileOpen(file_path, 'r') as file:
@@ -276,13 +281,26 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
276
281
  def set_dump_path(input_param):
277
282
  npu_path = input_param.get("npu_json_path", None)
278
283
  bench_path = input_param.get("bench_json_path", None)
279
- npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
280
- bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
281
- if not npu_path_valid or not bench_path_valid:
282
- logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
284
+ dump_json_path_valid = npu_path is not None and npu_path.endswith("dump.json") and \
285
+ bench_path is not None and bench_path.endswith("dump.json")
286
+ debug_json_path_valid = npu_path is not None and npu_path.endswith("debug.json") and \
287
+ bench_path is not None and bench_path.endswith("debug.json")
288
+ if not dump_json_path_valid and not debug_json_path_valid:
289
+ logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
290
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
291
+ input_param[CompareConst.NPU_DUMP_DATA_DIR] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
292
+ input_param[CompareConst.BENCH_DUMP_DATA_DIR] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
293
+
294
+
295
+ def get_file_type(file_path):
296
+ if not isinstance(file_path, str):
297
+ logger.error("get_file_type failed, check the type of file_path.")
298
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
299
+ file_type = file_suffix_to_file_type.get(file_path.split(Const.SCOPE_SEPARATOR)[-1])
300
+ if file_type is None:
301
+ logger.error("get_file_type failed, file_path is neither dump.json nor debug.json.")
283
302
  raise CompareException(CompareException.INVALID_PATH_ERROR)
284
- input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
285
- input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
303
+ return file_type
286
304
 
287
305
 
288
306
  def get_dump_mode(input_param):
@@ -290,6 +308,7 @@ def get_dump_mode(input_param):
290
308
  bench_path = input_param.get("bench_json_path", None)
291
309
  npu_json_data = load_json(npu_path)
292
310
  bench_json_data = load_json(bench_path)
311
+ json_type = get_file_type(file_path=npu_path)
293
312
 
294
313
  npu_task = npu_json_data.get('task', None)
295
314
  bench_task = bench_json_data.get('task', None)
@@ -309,8 +328,8 @@ def get_dump_mode(input_param):
309
328
  return Const.STRUCTURE
310
329
 
311
330
  if npu_task == Const.STATISTICS:
312
- npu_md5_compare = md5_find(npu_json_data['data'])
313
- bench_md5_compare = md5_find(bench_json_data['data'])
331
+ npu_md5_compare = md5_find(npu_json_data['data'], json_type)
332
+ bench_md5_compare = md5_find(bench_json_data['data'], json_type)
314
333
  if npu_md5_compare == bench_md5_compare:
315
334
  return Const.MD5 if npu_md5_compare else Const.SUMMARY
316
335
  else:
@@ -424,6 +443,37 @@ def get_real_step_or_rank(step_or_rank_input, obj):
424
443
  return real_step_or_rank
425
444
 
426
445
 
446
def check_init_step(step):
    """
    Validate an initial step value.

    :param step: value to validate; must pass ``is_int`` and be >= 0.
    :raises MsprobeException: with INVALID_PARAM_ERROR when either check fails.
    """
    error_code = MsprobeException.INVALID_PARAM_ERROR
    if not is_int(step):
        raise MsprobeException(error_code, f"{step} must be an integer")
    if step >= 0:
        return
    raise MsprobeException(error_code, f"{step} must be greater than or equal to 0")
453
+
454
+
455
def check_token_range(token_range):
    """
    Validate the ``token_range`` parameter.

    ``None`` is accepted without further checks. Any other value must be a list or
    tuple of exactly two integers ``(start, end)`` with ``0 <= start <= end``.
    Each failure is logged and then raised.

    NOTE(review): ``bool`` values pass the integer check because ``bool`` is a
    subclass of ``int`` - confirm whether that should be rejected.

    :param token_range: None, or a two-element list/tuple of integers.
    :raises MsprobeException: with INVALID_PARAM_ERROR when any check fails.
    """
    if token_range is None:
        return
    if not isinstance(token_range, (list, tuple)):
        logger.error("Token_range must be a list or tuple.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if len(token_range) != 2:
        logger.error("Token_range must contains exactly 2 elements.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    start, end = token_range
    if not isinstance(start, int) or not isinstance(end, int):
        logger.error("Start and end in token_range must be integer.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start > end:
        # start == end is accepted, so the message must not claim "less than".
        logger.error("Start in token_range must be less than or equal to the end.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start < 0:
        logger.error("Start in token_range must >= 0.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
475
+
476
+
427
477
  def check_seed_all(seed, mode, rm_dropout):
428
478
  if is_int(seed):
429
479
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
@@ -467,36 +517,6 @@ def safe_get_value(container, index, container_name, key=None):
467
517
  raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
468
518
 
469
519
 
470
- # 记录工具函数递归的深度
471
- recursion_depth = defaultdict(int)
472
-
473
-
474
- # 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。
475
- def recursion_depth_decorator(func_info):
476
- def decorator(func):
477
- @wraps(func)
478
- def wrapper(*args, **kwargs):
479
- func_id = id(func)
480
- recursion_depth[func_id] += 1
481
- if recursion_depth[func_id] > Const.MAX_DEPTH:
482
- msg = f"call {func_info} exceeds the recursion limit."
483
- logger.error_log_with_exp(
484
- msg,
485
- MsprobeException(
486
- MsprobeException.RECURSION_LIMIT_ERROR, msg
487
- ),
488
- )
489
- try:
490
- result = func(*args, **kwargs)
491
- finally:
492
- recursion_depth[func_id] -= 1
493
- return result
494
-
495
- return wrapper
496
-
497
- return decorator
498
-
499
-
500
520
  def check_str_param(param):
501
521
  if not re.match(Const.REGEX_PREFIX_PATTERN, param):
502
522
  logger.error('The parameter {} contains special characters.'.format(param))
@@ -509,4 +529,60 @@ class DumpPathAggregation:
509
529
  construct_file_path = None
510
530
  dump_tensor_data_dir = None
511
531
  free_benchmark_file_path = None
512
- debug_file_path = None
532
+ debug_file_path = None
533
+
534
+
535
def is_save_variable_valid(variable, valid_special_types, depth=0):
    """
    Recursively check whether ``variable`` is made only of allowed types.

    A value is valid when it is an instance of ``valid_special_types``, or a
    list/tuple whose items are all valid, or a dict whose keys are all ``str`` and
    whose values are all valid. Nesting deeper than ``Const.DUMP_MAX_DEPTH`` is
    treated as invalid.

    :param variable: the object to inspect.
    :param valid_special_types: a type or tuple of types accepted as leaves.
    :param depth: current nesting depth, 0 for the outermost call.
    :return: True if the whole structure is valid, otherwise False.
    """
    if depth > Const.DUMP_MAX_DEPTH:
        return False
    if isinstance(variable, valid_special_types):
        return True

    child_depth = depth + 1
    if isinstance(variable, (list, tuple)):
        for element in variable:
            if not is_save_variable_valid(element, valid_special_types, child_depth):
                return False
        return True
    if isinstance(variable, dict):
        for name, element in variable.items():
            if not isinstance(name, str):
                return False
            if not is_save_variable_valid(element, valid_special_types, child_depth):
                return False
        return True
    return False
547
+
548
+
549
def replace_last_occurrence(text, old, new):
    """
    Replace the last occurrence of ``old`` in ``text`` with ``new``.

    :param text: the string to process; ``None`` is returned unchanged.
    :param old: the substring to look for.
    :param new: the replacement substring.
    :return: ``text`` with its last ``old`` replaced, or ``text`` itself when
        ``old`` does not occur.
    """
    if text is None:
        return None
    position = text.rfind(old)
    if position == -1:
        return text
    # Everything from the last match onwards; replacing once here hits that match.
    head, tail = text[:position], text[position:]
    return head + tail.replace(old, new, 1)
556
+
557
+
558
def load_stack_json(stack_path):
    """
    Load a stack json file and return it as a mapping keyed by api name.

    When the loaded dict has no truthy ``Const.NEW_STACK_FLAG`` entry it is
    returned as loaded. Otherwise every value that has exactly two elements is
    unpacked as ``(api_list, stack_str)`` and each api name in the list is mapped
    to that stack string; values of any other length are skipped.

    :param stack_path: path of the stack json file.
    :return: dict mapping api name to stack string (or the raw loaded dict).
    """
    raw_stack = load_json(stack_path)
    if not raw_stack.get(Const.NEW_STACK_FLAG):
        return raw_stack

    stack_by_api = {}
    for entry in raw_stack.values():
        if len(entry) != 2:
            continue
        names, stack_text = entry
        for api in names:
            stack_by_api[api] = stack_text
    return stack_by_api
571
+
572
+
573
def analyze_api_call_stack(name):
    """
    Render the current call stack, minus the two innermost frames, as text.

    Frames without source context are left out. If the stack cannot be retrieved
    (a warning is logged) or is empty, ``Const.WITHOUT_CALL_STACK`` is used.

    :param name: name used only in the warning message.
    :return: the concatenated stack description.
    """
    try:
        # Must be called in this function's own frame so the [2:] slice drops the same frames.
        frames = inspect.stack()[2:]
    except Exception as e:
        logger.warning(f"The call stack of {name} failed to retrieve, {e}.")
        frames = None

    if frames:
        pieces = [
            f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()} \n"
            for (_, path, line, func, code, _) in frames
            if code
        ]
    else:
        pieces = [Const.WITHOUT_CALL_STACK]
    return "".join(pieces)
@@ -111,3 +111,10 @@ class BaseConfig:
111
111
  f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
112
112
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
113
113
  )
114
+
115
def _check_summary_mode(self):
    """
    Validate ``self.summary_mode``.

    A falsy value is accepted; any other value must be in ``Const.SUMMARY_MODE``,
    otherwise the error is reported through ``logger.error_log_with_exp`` with an
    INVALID_PARAM_ERROR exception.
    """
    if not self.summary_mode:
        return
    if self.summary_mode in Const.SUMMARY_MODE:
        return
    logger.error_log_with_exp(
        f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.",
        MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    )