mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,356 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from abc import ABC, abstractmethod
18
+ import copy
19
+ from collections import defaultdict
20
+ import functools
21
+ import os
22
+
23
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
24
+ from msprobe.core.common.file_utils import create_directory
25
+ from msprobe.core.common.runtime import Runtime
26
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
27
+ from msprobe.core.data_dump.api_registry import ApiRegistry
28
+ from msprobe.core.data_dump.data_collector import build_data_collector
29
+ from msprobe.core.hook_manager import BaseHookManager
30
+ from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json
31
+
32
+
33
class BaseService(ABC):
    """Framework-agnostic dump service with the shared start/stop/step/save control flow.

    Subclasses implement the abstract hooks declared below. They also inject
    ``logger``, ``api_register``, ``api_template`` and ``hook_manager``:
    ``__init__`` calls ``_init_specific_components`` and then immediately
    ``_register_api_hook``, which already needs ``api_register`` and ``logger``.
    """

    def __init__(self, config):
        # Work on a private copy; the level is overwritten below and must not leak to the caller.
        self.config = copy.deepcopy(config)
        # Compatibility with MindSpore configs: a `level_ori` attribute, when present, wins over `level`.
        self.config.level = getattr(config, 'level_ori', config.level)
        self.model = None
        self.data_collector = build_data_collector(self.config)
        self.attl_manager = None
        self.current_iter = 0
        self.loop = 0
        self.init_step = 0
        self.cur_token_id = 0
        self.first_start = True
        self.primitive_switch = False
        self.current_rank = None
        self.dump_iter_dir = None
        self.should_stop_service = False
        # "<module>.<api_name>" -> callable that was on the module before register_custom_api wrapped it.
        self.ori_customer_func = {}
        self.debug_variable_counter = None
        # Attribute name (with its historical spelling) kept for compatibility with existing users.
        self.currrent_step_first_debug_save = True
        self.logger = None  # injected by subclasses
        self.api_register = None  # injected by subclasses
        self.api_template = None  # injected by subclasses
        self.hook_manager = None  # injected by subclasses
        self._init_specific_components()
        self._register_api_hook()

    @property
    def _is_debug_level(self):
        """True when the configured level is the debug (``save``-only) level."""
        return self.config.level == Const.LEVEL_DEBUG

    @property
    def _is_l2_level(self):
        """True when the configured level is L2."""
        return self.config.level == Const.LEVEL_L2

    @property
    def _is_mix_level(self):
        """True when the configured level is mix."""
        return self.config.level == Const.LEVEL_MIX

    @property
    def _is_need_module_hook(self):
        """Module hooks are registered for the mix and L0 levels."""
        return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]

    @property
    def _is_need_api_hook(self):
        """API hooks are registered for the mix, L1 and L2 levels."""
        return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]

    @property
    def _is_no_dump_step(self):
        """Truthy when a step filter is configured and the current iteration is not in it."""
        return (self.config.step and self.current_iter not in self.config.step)

    @property
    def _is_no_dump_rank(self):
        """Truthy when a rank filter is configured and the current rank is not in it."""
        return (self.config.rank and self.current_rank not in self.config.rank)

    @property
    def _need_tensor_data(self):
        """Whether tensor data (not only statistics) has to be collected."""
        return bool(
            self.config.task in self.data_collector.tasks_need_tensor_data or
            (self.config.task == Const.STATISTICS and self.config.tensor_list)
        )

    @property
    def _is_online_run_ut(self):
        """Whether online run_ut is enabled; configs without the attribute count as disabled."""
        return getattr(self.config, "online_run_ut", False)

    @property
    @abstractmethod
    def _get_framework_type(self):
        """Return the framework type; passed to ``data_collector.initialize_json_file``."""
        pass

    @staticmethod
    @abstractmethod
    def _get_current_rank():
        """Return the current rank id.

        May raise ``DistributedNotInitializedError``; callers map that to ``None``.
        """
        pass

    @staticmethod
    def _change_jit_switch(status):
        """Toggle the JIT dump switch. No-op here; the MindSpore subclass overrides it."""
        pass

    def start(self, model=None, token_range=None):
        """Turn dumping on for the current iteration.

        The first effective call resolves the rank and registers hooks once. Ranks excluded
        by the rank filter return before any hook is registered.

        Args:
            model: model (or list of models) to hook; stored on ``self.model``.
            token_range: optional ``[start, end]`` token window, both ends inclusive. When it
                is given, the dump switch is driven by a forward pre-hook counting model calls
                instead of being turned on here.
        """
        self._process_iteration()
        if self._is_debug_level:
            return
        if self._need_stop_service():
            return
        self.model = model
        self.cur_token_id = 0
        if self.first_start:
            try:
                self.current_rank = self._get_current_rank()
            except DistributedNotInitializedError:
                self.current_rank = None
            Runtime.current_rank = self.current_rank
            if self._is_no_dump_rank:
                return
            self._register_hook()
            if self._is_need_module_hook:
                self._register_module_hook()
            self.first_start = False

        if token_range:
            # NOTE(review): every start() call with token_range registers another pre-hook on
            # the model; confirm start() is called only once per model in that mode.
            self._register_infer_count_hook(self.model, token_range)
        self.logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")
        if token_range is None:
            Runtime.is_running = True
            self.primitive_switch = True
            self._change_jit_switch(True)
        self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
        if self._is_online_run_ut:
            # NOTE(review): _run_ut_dispatch is not defined on this class; presumably the
            # framework subclass provides it.
            self._run_ut_dispatch(True)
        else:
            self.create_dirs()
            self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Turn dumping off and flush the data collected so far.

        At the L2 level, only the switches are turned off; no JSON is written here.
        """
        if self._is_debug_level or self.should_stop_service:
            return
        if self._is_no_dump_step or self._is_no_dump_rank:
            return
        self.logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
                         "Please set debugger.start() to turn on the dump switch again. ")
        Runtime.is_running = False
        self.primitive_switch = False
        self._change_jit_switch(False)
        if self._is_l2_level:
            return
        if self._is_online_run_ut:
            self._run_ut_dispatch(False)
        self._process_async_dump()
        self.data_collector.write_json()

    def step(self):
        """Finish the current iteration: flush data, advance the loop counter and reset state."""
        if self.should_stop_service:
            return
        self._process_async_dump()
        self.data_collector.write_json()
        # The next save() in the new step re-resolves the rank, recreates dirs and resets counters.
        self.currrent_step_first_debug_save = True
        self.loop += 1
        self._reset_status()

    def save(self, variable, name, save_backward):
        """Save an arbitrary variable; only active at the debug level on selected steps.

        The first save of each step refreshes the rank, creates the dump directories and
        resets the per-name counters. Repeated saves of the same ``name`` within one step
        are therefore recorded as ``name.0``, ``name.1``, ...

        Args:
            variable: value to record; per the original doc, a list or dict of variables,
                a tensor, str, float or int.
            name: base name of the record.
            save_backward: when true, also collect the gradient as ``<name>_grad.<count>``.
        Returns:
            None
        """
        if not self._is_debug_level:
            return
        self.current_iter = self.loop + self.init_step
        if self._is_no_dump_step:
            return

        if self.currrent_step_first_debug_save:
            try:
                self.current_rank = self._get_current_rank()
            except DistributedNotInitializedError:
                self.current_rank = None

            self.create_dirs()
            self.debug_variable_counter = defaultdict(int)
            self.currrent_step_first_debug_save = False

        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1

        name_with_count = f"{name}.{count}"
        grad_name_with_count = f"{name}_grad.{count}"

        # forward save
        self.data_collector.debug_data_collect_forward(variable, name_with_count)

        # backward save
        if save_backward:
            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)

    def register_custom_api(self, module, api_name, api_prefix):
        """Hook ``module.<api_name>`` through ``ApiRegistry`` and remember the original callable.

        Only the first original seen for a given module/API is kept. ``ApiRegistry`` presumably
        replaces the module attribute with a wrapper (``restore_custom_api`` puts the saved
        callable back with ``setattr``). Recording again on a repeated registration would
        store the wrapper and make the true original unrecoverable.
        """
        self.ori_customer_func.setdefault(str(module) + Const.SEP + api_name, getattr(module, api_name))
        ApiRegistry.register_custom_api(module, api_name, api_prefix,
                                        functools.partial(self.build_hook, Const.API), self.api_template)

    def restore_custom_api(self, module, api):
        """Put back the callable saved by ``register_custom_api``; no-op if none was saved."""
        ori_func = self.ori_customer_func.get(str(module) + Const.SEP + api)
        if ori_func is not None:
            setattr(module, api, ori_func)

    def build_hook(self, hook_type, name):
        """Delegate hook construction to the subclass-injected ``hook_manager``."""
        return self.hook_manager.build_hook(hook_type, name)

    def create_dirs(self):
        """Create the dump directory tree for the current iteration and rank.

        Layout: ``<dump_path>[/<PYNATIVE_MODE>]/step<iter>``. The middle component is added
        when ``Runtime.run_mode`` is ``Const.PYNATIVE_GRAPH_MODE``. An unknown rank
        (``None``) is rendered as an empty string in rank-specific names.
        """
        create_directory(self.config.dump_path)
        if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
            self.dump_iter_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, f"step{self.current_iter}")
        else:
            self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")

        cur_rank = self.current_rank if self.current_rank is not None else ''
        if self._is_l2_level:
            self._create_l2_dirs(cur_rank)
        else:
            self._create_default_dirs(cur_rank)

    @abstractmethod
    def _init_specific_components(self):
        """Initialize framework-specific components."""
        pass

    @abstractmethod
    def _register_hook(self):
        """Register hook functions."""
        pass

    @abstractmethod
    def _register_module_hook(self):
        """Register module-level hook functions."""
        pass

    def _need_stop_service(self):
        """Return True when dumping should not run for the current iteration.

        On the first detection that the last configured step has passed, or that the data
        processor has terminated, switches are turned off (and ATTL stopped for online
        run_ut), the service is marked stopped and the end banner is printed.
        """
        if self.should_stop_service:
            return True
        end_service = self.config.step and self.current_iter > max(self.config.step) or \
            self.data_collector and self.data_collector.data_processor.is_terminated
        if end_service:
            if self._is_online_run_ut and self.attl_manager:
                self.attl_manager.attl_stop()
            self.primitive_switch = False
            self._change_jit_switch(False)
            Runtime.is_running = False
            self.should_stop_service = True
            print_tools_ends_info()
            return True
        if self._is_no_dump_step:
            return True
        return False

    def _register_api_hook(self):
        """Install API hooks through the injected ``api_register`` when the level needs them."""
        if self._is_need_api_hook:
            self.api_register.initialize_hook(functools.partial(self.build_hook, Const.API))
            self.api_register.register_all_api()
            self.logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")

    def _register_infer_count_hook(self, root_model, token_range):
        """Count forward calls of ``root_model`` to track which inference token is being produced.

        Args:
            root_model: inference model to hook. If a list is given, only its first element
                is hooked.
            token_range: ``[start, end]`` token window to dump, both ends inclusive.
        Returns:
            None. Nothing is registered in online run_ut mode.
        """
        def infer_hook(model, args):
            if self.cur_token_id == token_range[0]:
                Runtime.is_running = True
                self.primitive_switch = True
                self._change_jit_switch(True)
                self.logger.info(f"Current token id: {self.cur_token_id}, start dump infer token.")
            elif token_range[0] < self.cur_token_id <= token_range[1]:
                self.logger.debug(f"Current token id: {self.cur_token_id}.")
            elif self.cur_token_id == token_range[1] + 1:
                Runtime.is_running = False
                self.primitive_switch = False
                self._change_jit_switch(False)
                self.logger.info(
                    f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.")
            self.cur_token_id += 1

        if isinstance(root_model, list):
            # Warn only when other models are actually being ignored.
            if len(root_model) > 1:
                self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
            root_model = root_model[0]
        if self._is_online_run_ut:
            return
        root_model.register_forward_pre_hook(infer_hook)

    def _create_l2_dirs(self, cur_rank):
        """Create the step dir and write the kernel config JSON used at the L2 level."""
        create_directory(self.dump_iter_dir)
        kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
        self.config.kernel_config_path = kernel_config_path

    def _create_default_dirs(self, cur_rank):
        """Create ``rank<id>`` (and ``dump_tensor_data`` if tensors are needed), then wire the paths."""
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        create_directory(dump_dir)

        dump_data_dir = None
        if self._need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)

        self._configure_dump_paths(dump_dir, dump_data_dir)

    def _configure_dump_paths(self, dump_dir, dump_data_dir):
        """Point the data collector at the output files under ``dump_dir`` and initialize its JSON."""
        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        self.data_collector.initialize_json_file(self._get_framework_type)

    def _process_iteration(self):
        """Recompute the current iteration and publish it to the collector and ``Runtime``."""
        self.current_iter = self.loop + self.init_step
        self.data_collector.update_iter(self.current_iter)
        Runtime.current_iter = self.current_iter

    def _process_async_dump(self):
        """Flush asynchronously buffered data for the statistics and tensor tasks."""
        if self.config.async_dump and self.config.task in [Const.STATISTICS, Const.TENSOR]:
            self.data_collector.data_processor.dump_async_data()

    def _reset_status(self):
        """Reset per-step state in the collector and the shared gradient info."""
        self.data_collector.reset_status()
        BaseHookManager.params_grad_info.clear()
        if self._is_l2_level:
            self.data_collector.data_processor.reset_status()
File without changes
@@ -0,0 +1,243 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
import multiprocessing
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory, save_excel
from msprobe.core.common.log import logger
26
+
27
+
28
@dataclass
class CompareResult:
    """Metrics produced by ``SingleComparator.compare_arrays`` for one pair of arrays.

    Percentages are expressed in the range 0-100.
    """
    max_abs_error: float  # largest |a - b| over the compared region
    max_relative_error: float  # largest |a - b| / max(|a|, |b|); 0/0 positions count as 0
    same_percentage: float  # percentage of exactly equal elements
    # Flat (C-order) index of the first unequal element, or None when all elements are equal.
    first_mismatch_index: Optional[int]
    percentage_within_thousandth: float  # % of elements with |a - b| <= 0.001 * max(|a|, |b|)
    percentage_within_hundredth: float  # % of elements with |a - b| <= 0.01 * max(|a|, |b|)
36
+
37
+
38
class SingleComparator:
    """Compare NumPy arrays dumped into two directory trees and write one Excel report per tag.

    The ``get_*`` helpers expect this layout below each ``data`` directory::

        <tag>/step<N>/rank<R>/[micro_step<M>/]<name>[.<id>].npy

    Arrays present on both sides under the same (step, rank, micro_step, id) key are compared.
    """

    # Excel report columns. The Chinese headers are runtime strings and read, in order:
    # same-element percentage (%), index of the first mismatching element, max absolute error,
    # max relative error, share of elements within 0.1% error (%), share within 1% error (%).
    result_header = [
        'step',
        'rank',
        'micro_step',
        'id',
        'shape1',
        'shape2',
        '相同元素百分比(%)',
        '首个不匹配元素索引',
        '最大绝对误差',
        '最大相对误差',
        '误差在千分之一内元素占比(%)',
        '误差在百分之一内元素占比(%)'
    ]

    @classmethod
    def compare(cls, dir1, dir2, output_path="./msprobe_compare_output", num_processes=8):
        """Compare ``<dir1>/data`` against ``<dir2>/data`` and write the reports to *output_path*.

        Args:
            dir1: Root of the first dump; must contain a ``data`` sub-directory.
            dir2: Root of the second dump; must contain a ``data`` sub-directory.
            output_path: Directory that receives one ``<tag>.xlsx`` per tag; created if missing.
            num_processes: Number of worker processes used for the comparison.
        """
        data_dir1 = os.path.join(dir1, "data")
        data_dir2 = os.path.join(dir2, "data")
        check_file_or_directory_path(data_dir1, isdir=True)
        check_file_or_directory_path(data_dir2, isdir=True)
        # Make sure the output directory exists before any worker tries to write into it.
        if not os.path.exists(output_path):
            create_directory(output_path)
        cls.compare_data(data_dir1, data_dir2, output_path, num_processes)

    @classmethod
    def compare_arrays(cls, array1, array2) -> "CompareResult":
        """Compare two arrays element-wise over their overlapping region.

        Arrays of different shape are truncated to the per-axis minimum shape. Arrays of different
        rank are flattened (C order) first and compared over their common leading elements, instead
        of being silently broadcast against each other. Error arithmetic is done in at least float64,
        so integer/unsigned inputs do not wrap around, bool inputs can be subtracted and float16
        differences do not overflow. Equality checks use the original values.

        Returns:
            CompareResult. When the overlap is empty, every metric is NaN and
            ``first_mismatch_index`` is None.
        """
        array1 = np.asarray(array1)
        array2 = np.asarray(array2)
        if array1.ndim != array2.ndim:
            array1 = array1.reshape(-1)
            array2 = array2.reshape(-1)

        # Truncate both arrays to the per-axis minimum shape.
        overlap = tuple(slice(0, min(s1, s2)) for s1, s2 in zip(array1.shape, array2.shape))
        sliced_array1 = np.asarray(array1[overlap])
        sliced_array2 = np.asarray(array2[overlap])

        total_elements = sliced_array1.size
        if total_elements == 0:
            nan = float("nan")
            return CompareResult(nan, nan, nan, None, nan, nan)

        work_dtype = np.result_type(sliced_array1.dtype, sliced_array2.dtype, np.float64)
        values1 = sliced_array1.astype(work_dtype, copy=False)
        values2 = sliced_array2.astype(work_dtype, copy=False)

        abs_error = np.abs(values1 - values2)
        max_abs_error = np.max(abs_error)

        magnitude = np.maximum(np.abs(values1), np.abs(values2))
        with np.errstate(divide='ignore', invalid='ignore'):
            # Both elements zero gives 0/0 = NaN; nan_to_num maps that to a relative error of 0.
            relative_error = np.nan_to_num(abs_error / magnitude)
        max_relative_error = np.max(relative_error)

        same_elements = np.sum(sliced_array1 == sliced_array2)
        same_percentage = (same_elements / total_elements) * 100

        mismatch_indices = np.flatnonzero(sliced_array1 != sliced_array2)
        first_mismatch_index = int(mismatch_indices[0]) if mismatch_indices.size > 0 else None

        error_within_thousandth = np.sum(abs_error <= 0.001 * magnitude)
        percentage_within_thousandth = (error_within_thousandth / total_elements) * 100

        error_within_hundredth = np.sum(abs_error <= 0.01 * magnitude)
        percentage_within_hundredth = (error_within_hundredth / total_elements) * 100

        return CompareResult(
            max_abs_error,
            max_relative_error,
            same_percentage,
            first_mismatch_index,
            percentage_within_thousandth,
            percentage_within_hundredth
        )

    @staticmethod
    def _iter_numbered_dirs(parent, prefix):
        """Yield ``(number, path)`` for each sub-directory of *parent* named ``<prefix><int>``.

        Raises:
            RuntimeError: a sub-directory starts with *prefix* but the rest is not an integer.
        """
        for entry in os.listdir(parent):
            entry_path = os.path.join(parent, entry)
            if not entry.startswith(prefix) or not os.path.isdir(entry_path):
                continue
            try:
                number = int(entry[len(prefix):])
            except ValueError as e:
                raise RuntimeError(f"parse {prefix} number error: {entry_path}") from e
            yield number, entry_path

    @classmethod
    def get_steps(cls, tag_path):
        """Yield ``(step, path)`` for every ``step<N>`` directory inside *tag_path*."""
        yield from cls._iter_numbered_dirs(tag_path, 'step')

    @classmethod
    def get_ranks(cls, step_path):
        """Yield ``(rank, path)`` for every ``rank<N>`` directory inside *step_path*."""
        yield from cls._iter_numbered_dirs(step_path, 'rank')

    @classmethod
    def get_micro_steps(cls, rank_path):
        """Yield ``(micro_step, path)`` for every ``micro_step<N>`` directory inside *rank_path*.

        When there are no micro-step directories, *rank_path* itself is yielded exactly once as
        micro step 0. It used to be yielded once per unrelated entry, which duplicated arrays.
        """
        micro_steps = list(cls._iter_numbered_dirs(rank_path, 'micro_step'))
        if micro_steps:
            yield from micro_steps
        else:
            yield 0, rank_path

    @classmethod
    def get_arrays(cls, micro_step_path):
        """Yield ``(array_id, path)`` for every ``.npy`` file in *micro_step_path*.

        ``<name>.<id>.npy`` files carry an explicit numeric id; any other name gets id 0.
        """
        for file_name in os.listdir(micro_step_path):
            if not file_name.endswith('.npy'):
                continue
            parts = file_name.rsplit('.', 2)
            # isdecimal() accepts exactly the digit strings int() can parse (unlike isdigit()).
            array_id = int(parts[-2]) if parts[-2].isdecimal() else 0
            yield array_id, os.path.join(micro_step_path, file_name)

    @classmethod
    def get_array_paths(cls, dir_path):
        """Collect every array file below *dir_path*, grouped by tag.

        Returns:
            ``{tag: [(step, rank, micro_step, array_id, path), ...]}``; empty if *dir_path* is missing.
        """
        array_paths = {}
        if not os.path.exists(dir_path):
            return array_paths
        for tag in os.listdir(dir_path):
            tag_path = os.path.join(dir_path, tag)
            if not os.path.isdir(tag_path):
                continue
            for step, step_path in cls.get_steps(tag_path):
                for rank, rank_path in cls.get_ranks(step_path):
                    for micro_step, micro_step_path in cls.get_micro_steps(rank_path):
                        for array_id, array_path in cls.get_arrays(micro_step_path):
                            array_paths.setdefault(tag, []).append((step, rank, micro_step, array_id, array_path))
        return array_paths

    @classmethod
    def compare_single_tag(cls, tag, array_paths1, array_paths2, output_dir):
        """Compare all arrays of *tag* and save the table to ``<output_dir>/<tag>.xlsx``.

        Only keys present on both sides are compared. Errors are logged rather than raised, because
        this runs in a pool worker and one bad file or tag must not abort the remaining work.
        """
        try:
            data = []
            path_dict1 = {(step, rank, micro_step, array_id): path
                          for step, rank, micro_step, array_id, path in array_paths1.get(tag, [])}
            path_dict2 = {(step, rank, micro_step, array_id): path
                          for step, rank, micro_step, array_id, path in array_paths2.get(tag, [])}
            for key in sorted(path_dict1.keys() & path_dict2.keys()):
                path1, path2 = path_dict1[key], path_dict2[key]
                try:
                    array1 = np.load(path1, allow_pickle=False)
                    array2 = np.load(path2, allow_pickle=False)
                    result = cls.compare_arrays(array1, array2)
                except Exception as e:
                    logger.error(f"Error comparing {path1} and {path2}: {e}")
                    continue
                step, rank, micro_step, array_id = key
                data.append([
                    step, rank, micro_step, array_id,
                    list(array1.shape), list(array2.shape),
                    result.same_percentage,
                    result.first_mismatch_index,
                    result.max_abs_error,
                    result.max_relative_error,
                    result.percentage_within_thousandth,
                    result.percentage_within_hundredth
                ])

            df = pd.DataFrame(data, columns=cls.result_header)
            df = df.sort_values(by=['step', 'rank', 'micro_step', 'id'])
            output_file_path = os.path.join(output_dir, f'{tag}.xlsx')
            save_excel(output_file_path, df)
        except Exception as e:
            logger.error(f"Error processing tag {tag}: {e}")

    @classmethod
    def _compare_tag_task(cls, task):
        """Unpack one ``(tag, array_paths1, array_paths2, output_dir)`` task for ``imap_unordered``."""
        return cls.compare_single_tag(*task)

    @classmethod
    def compare_data(cls, dir1, dir2, output_dir, num_processes=8):
        """Compare the NumPy array files of two directory trees and save one Excel file per tag
        in *output_dir*, using a pool of *num_processes* workers."""
        array_paths1 = cls.get_array_paths(dir1)
        array_paths2 = cls.get_array_paths(dir2)

        all_tags = set(array_paths1) | set(array_paths2)
        # Send each worker only its own tag's entries instead of pickling both full dicts per task.
        tasks = [
            (tag, {tag: array_paths1.get(tag, [])}, {tag: array_paths2.get(tag, [])}, output_dir)
            for tag in sorted(all_tags)
        ]

        with multiprocessing.Pool(processes=num_processes) as pool:
            try:
                with tqdm(total=len(tasks), desc="Processing data") as pbar:
                    # imap_unordered hands back each finished tag as it completes, so the progress
                    # bar advances without busy-polling the pool or reading its private counters.
                    for _ in pool.imap_unordered(cls._compare_tag_task, tasks):
                        pbar.update(1)
            except Exception as e:
                logger.error(f"Multiprocessing error: {e}")
            finally:
                pool.close()
                pool.join()