mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/monitor/optimizer_collect.py ADDED
@@ -0,0 +1,331 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from abc import abstractmethod
+
+ from mindspore import mint, ops
+
+ from msprobe.mindspore.common.log import logger
+ from msprobe.core.common.const import MonitorConst
+
+
+ class OptimizerMon(object):
+     def __init__(self, optim) -> None:
+         self.fp16_to_fp32_param = {}
+         self.optim = optim
+         self.state = {}
+
+     def narrow_from_flatten(self, param, flatten_state):
+         return flatten_state
+
+     def get_state(self, optim):
+         if hasattr(optim, 'chained_optimizers'):
+             for opt in optim.chained_optimizers:
+                 self._get_single_state(opt)
+         else:
+             self._get_single_state(optim)
+
+     def fetch_grad(self, monitor, params2name):
+         if not self.fp16_to_fp32_param:
+             self.map_fp16_to_fp32_param(self.optim)
+
+         grad_dict = {}
+         first_param = True
+         for param, name in params2name.items():
+             if monitor.duplicate_param.get(name, False):
+                 continue
+             if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                 continue
+             grad = param.main_grad if monitor.params_have_main_grad else param.grad
+             element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+             if param.numel() != element_in_cur_partition:
+                 if first_param:
+                     grad = grad.flatten()[-element_in_cur_partition:]
+                 else:  # supposed to be the last one
+                     grad = grad.flatten()[:element_in_cur_partition]
+                 first_param = False
+             if grad is None:
+                 continue
+             tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+             monitor.register_param_call_id("hook_optimizer", tag)
+             grad_dict[tag] = grad
+         return grad_dict
+
+     def map_fp16_to_fp32_param(self, optim):
+         pass
+
+     def fetch_mv(self, monitor, params2name):
+         if not self.fp16_to_fp32_param:
+             self.map_fp16_to_fp32_param(self.optim)
+         if not self.state:
+             self.get_state(self.optim)
+
+         exp_avg_dict = {}
+         exp_avg_sq_dict = {}
+         update_dict = {}
+         ratio_dict = {}
+
+         if not self.state:
+             logger.warning('optimizer state cannot be accessed')
+             return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+         for lp_param, name in params2name.items():
+             if lp_param in self.fp16_to_fp32_param:
+                 hp_param = self.fp16_to_fp32_param[lp_param]
+             else:
+                 hp_param = lp_param
+
+             if hp_param in self.state:
+                 state_param = self.state.get(hp_param, {})
+                 exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                 exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
+                 if monitor.mv_distribution:
+                     exp_avg_dict[name] = exp_avg
+                     exp_avg_sq_dict[name] = exp_avg_sq
+                 if monitor.mg_direction:
+                     exp_avg_dict[name] = exp_avg
+                 if monitor.ur_distribution:
+                     if len(self.optim.param_groups) > 1:
+                         logger.info(f"the length of optim.param_groups is {len(self.optim.param_groups)}.")
+                     if 'step' in state_param:
+                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
+                     elif 'step' in self.optim.param_groups[0]:
+                         step = self.optim.param_groups[0]['step']  # AdamW from mindspeed
+                     else:
+                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                         continue
+                     exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step)
+                     exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step)
+                     update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps'])
+                     ratio_dict[name] = exp_avg_hat / mint.sqrt(exp_avg_sq_hat)
+                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+         return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+     def _get_single_state(self, optim):
+         state = {}
+         if hasattr(optim, 'param_to_cpu_states_map'):
+             state = optim.param_to_cpu_states_map
+         elif hasattr(optim, 'state'):
+             state = optim.state
+         elif hasattr(optim, 'optimizer') and hasattr(optim.optimizer, 'state'):
+             state = optim.optimizer.state
+         self.state.update(state)
+
+
+ class MixPrecisionOptimizerMon(OptimizerMon):
+     """
+     Monitor class for mixed-precision optimizers: tracks and manages the optimizer
+     during mixed-precision training, which speeds up training and reduces memory
+     consumption by appropriately lowering the precision of certain computations.
+     """
+     def map_fp16_to_fp32_param(self, optim):
+         for fp16_group, fp32_group in zip(optim.float16_groups, optim.fp32_from_float16_groups):
+             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = fp32_param
+
+
+ class MegatronDistributedOptimizerMon(OptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         if not (hasattr(optim, "model_float16_groups") and
+                 hasattr(optim, "shard_fp32_from_float16_groups")):
+             raise Exception(
+                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
+                 "if not, please check megatron-lm version")
+         for fp16_group, shard_fp32_group in zip(optim.model_float16_groups,
+                                                 optim.shard_fp32_from_float16_groups):
+             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+
+ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         for opt in optim.chained_optimizers:
+             super().map_fp16_to_fp32_param(opt)
+
+
+ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         for opt in optim.chained_optimizers:
+             super().map_fp16_to_fp32_param(opt)
+
+
+ class DeepSpeedZeroOptimizerMon(OptimizerMon):
+     """
+     Base monitor class for DeepSpeed ZeRO optimizers.
+     ZeRO stage 0: no partitioning.
+     ZeRO stage 1: partitions optimizer states across data parallel processes.
+     ZeRO stage 2: additionally partitions gradients.
+     ZeRO stage 3: additionally partitions parameters.
+
+     This class provides monitoring capabilities for ZeRO optimizers by:
+     - Handling gradient collection for different ZeRO stages
+     - Managing optimizer state access for monitoring
+     """
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = ''
+         self.bit16_groups = []
+         self.fp32_flat_groups = []
+         self.param2group = ()
+         self.param2index = []
+         self.group_offset = {}
+
+     @abstractmethod
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         raise NotImplementedError
+
+     def param_not_in_partition(self, lp_param, group_idx):
+         param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+         hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+         return hp_address is None
+
+     def get_position(self, lp_param, group_idx):
+         param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+         hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+         return hp_address.start, hp_address.numel
+
+     def get_group_index(self):
+         param2group = {}
+         for group_idx, bit16_group in enumerate(self.bit16_groups):
+             for param in bit16_group:
+                 param2group[param] = group_idx
+         return param2group
+
+     def get_param_index(self, lp_param, group_idx):
+         if not self.param2index:
+             for group in self.bit16_groups:
+                 param2index = {}
+                 for index, param in enumerate(group):
+                     param2index[param] = index
+                 self.param2index.append(param2index)
+
+         return self.param2index[group_idx][lp_param]
+
+     def narrow_from_flatten(self, param, flatten_state):
+         if flatten_state is None:
+             return flatten_state
+         group_idx = self.param2group[param]
+         if self.param_not_in_partition(param, group_idx):
+             return None
+         start, numel = self.get_position(param, group_idx)
+         return flatten_state.narrow(0, start, numel)
+
+     def map_fp16_to_fp32_param(self, optim):
+         for group_idx, group in enumerate(self.bit16_groups):
+             for param in group:
+                 self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+     def fetch_grad(self, monitor, params2name):
+         grad_dict = {}
+         for lp_param, name in params2name.items():
+             group_idx = self.param2group[lp_param]
+             param_id = self.get_param_index(lp_param, group_idx)
+             if self.param_not_in_partition(lp_param, group_idx):
+                 continue
+             if self.stage == '1or2':
+                 param_id = param_id - self.group_offset[group_idx] - 1
+             grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+             tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+             monitor.register_param_call_id("hook_optimizer", tag)
+             grad_dict[tag] = grad
+
+         return grad_dict
+
+
+ class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '0'
+         self.bit16_groups = optim.bf16_groups
+         self.fp32_flat_groups = optim.fp32_groups_flat_partition
+         self.param2group = self.get_group_index()
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         return self.optim.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+ class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '1or2'
+         self.bit16_groups = optim.bit16_groups
+         self.fp32_flat_groups = optim.single_partition_of_fp32_groups
+         self.param2group = self.get_group_index()
+         self.group_offset = {}
+         self.get_group_offset()
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         if getattr(self.optim, "cpu_offload", False):
+             grads = self.optim.single_partition_of_fp32_groups[group_idx].grad
+             start, numel = self.get_position(lp_param, group_idx)
+             grad = grads.narrow(0, start, numel)
+         else:
+             grad = self.optim.averaged_gradients[group_idx][param_id]
+         return grad
+
+     def get_group_offset(self):
+         for group_idx, group in enumerate(self.bit16_groups):
+             self.group_offset[group_idx] = -1
+             for lp_param in group:
+                 if self.param_not_in_partition(lp_param, group_idx):
+                     self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
+                 else:
+                     break
+
+
+ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '3'
+         self.bit16_groups = optim.fp16_groups
+         self.fp32_flat_groups = optim.fp32_partitioned_groups_flat
+         self.param2group = self.get_group_index()
+
+     def param_not_in_partition(self, param, group_index):
+         """Each param is partitioned across all ZeRO ranks"""
+         return False
+
+     def get_position(self, lp_param, group_idx):
+         param_id = self.optim.get_param_id(lp_param)
+         return self.optim.grad_position[param_id][1:]
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         return self.optim.averaged_gradients[group_idx][param_id]
+
+
+ class OptimizerMonFactory:
+     _optimizer_mon_map = {
+         "FP32Optimizer": OptimizerMon,
+         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+         "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
+         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+         "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+         "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
+         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
+         "Adam": OptimizerMon
+     }
+
+     @staticmethod
+     def create_optimizer_mon(optimizer):
+         # resolve the concrete optimizer type; ChainedOptimizer is keyed by its first wrapped optimizer
+         optimizer_class = optimizer.__class__.__name__
+         if optimizer_class == "ChainedOptimizer":
+             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+         logger.info(f'The optimizer type is {optimizer_class}')
+
+         optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+         return optimizer_mon_class(optimizer)
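
The factory at the end of this new module dispatches purely on the optimizer's class name, falling back to the base `OptimizerMon`, with a `ChainedOptimizer` keyed by the class of its first wrapped optimizer. A minimal standalone sketch of that dispatch rule (plain Python; the classes and map entries here are stand-ins, not the real Megatron/DeepSpeed types):

```python
# Stand-in classes; the real map keys are Megatron/DeepSpeed class names.
class FP32Optimizer:
    pass

class ChainedOptimizer:
    def __init__(self, optimizers):
        self.chained_optimizers = optimizers

_mon_map = {"FP32Optimizer": "base monitor", "ChainedFP32Optimizer": "chained monitor"}

def resolve(optimizer):
    # Same rule as OptimizerMonFactory.create_optimizer_mon: key on the class
    # name, and unwrap ChainedOptimizer via its first chained optimizer.
    name = optimizer.__class__.__name__
    if name == "ChainedOptimizer":
        name = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
    return _mon_map.get(name, "default monitor")

print(resolve(FP32Optimizer()))                      # base monitor
print(resolve(ChainedOptimizer([FP32Optimizer()])))  # chained monitor
```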
msprobe/mindspore/monitor/utils.py CHANGED
@@ -12,27 +12,36 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
+ import os
+ import re
+ from datetime import datetime
  from mindspore import dtype as mstype, Tensor
  
  from msprobe.mindspore.monitor.features import FUNC_MAP
  from msprobe.core.common.const import MonitorConst
  from msprobe.core.common.utils import is_int
  from msprobe.core.common.log import logger
+ from msprobe.core.common.file_utils import check_file_or_directory_path
  
  
- def get_single_metrics(op_list, tag, tensor, output=None):
+ def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
      if output is None:
          output = {}
      if tag not in output:
          output[tag] = {}
      for op in op_list:
          func = FUNC_MAP.get(op)
-         statistic = func(tensor)
+         if op == "zeros":
+             statistic = func(tensor, eps)
+         else:
+             statistic = func(tensor)
          if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
              statistic = float(statistic)
              statistic = Tensor(statistic)
-         output[tag][op] = statistic.astype(mstype.float32)
+         if isinstance(statistic, Tensor):
+             output[tag][op] = statistic.astype(mstype.float32)
+         else:
+             output[tag][op] = statistic
  
  
  def get_metrics(op_list, tag2tensor, eps, output=None):
@@ -41,7 +50,7 @@ def get_metrics(op_list, tag2tensor, eps, output=None):
      for tag, tensor in tag2tensor.items():
          if tag not in output:
              output[tag] = {}
-         get_single_metrics(op_list, tag, tensor, output)
+         get_single_metrics(op_list, tag, tensor, eps, output)
      return output
  
  
@@ -88,6 +97,11 @@ def validate_ops(ops):
          default_op = MonitorConst.OP_LIST[0]
          valid_ops.append(default_op)
          logger.info(f"There is no valid ops, default op {default_op} is used")
+     # add the default "shape" and "dtype" ops
+     if "shape" not in valid_ops:
+         valid_ops.append("shape")
+     if "dtype" not in valid_ops:
+         valid_ops.append("dtype")
      return valid_ops
  
  
@@ -95,8 +109,8 @@ def validate_ranks(ranks):
      if not isinstance(ranks, list):
          raise TypeError("module_ranks should be a list")
      for rank in ranks:
-         if not isinstance(rank, str):
-             raise TypeError(f"element in module_ranks should be a str, get {type(rank)}")
+         if not isinstance(rank, int):
+             raise TypeError(f"element in module_ranks should be an int, get {type(rank)}")
  
  
  def validate_targets(targets):
@@ -168,7 +182,7 @@ def validate_alert(alert):
      args = rule.get("args")
      if args and isinstance(args, dict):
          threshold = args.get("threshold")
-         if not isinstance(threshold, float) or threshold < 0:
+         if not isinstance(threshold, (float, int)) or threshold < 0:
              raise TypeError('threshold must be float and not less than 0')
      dump = alert.get('dump')
      if dump and not isinstance(dump, bool):
@@ -209,6 +223,18 @@ def validate_collect_times(collect_times):
          raise ValueError("collect_times must greater than 1")
  
  
+ def validate_dynamic_on(dynamic_on):
+     if not isinstance(dynamic_on, bool):
+         raise TypeError('dynamic_on should be a bool')
+
+
+ def validate_monitor_mbs_grad(monitor_mbs_grad):
+     if not isinstance(monitor_mbs_grad, bool):
+         logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
+         return False
+     return monitor_mbs_grad
+
+
  def validate_config(config):
      config['ops'] = validate_ops(config.get('ops', []))
  
@@ -255,9 +281,14 @@ def validate_config(config):
      step_interval = config.get('step_interval', 1)
      validate_step_interval(step_interval)
  
-     collect_times = config.get('collect_times', 1e8)
+     collect_times = config.get('collect_times', int(1e8))
      validate_collect_times(collect_times)
  
+     config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
+
+     dynamic_on = config.get('dynamic_on', False)
+     validate_dynamic_on(dynamic_on)
+
      if not targets:
          if xy_distribution:
              config["all_xy"] = True
@@ -265,3 +296,34 @@ def validate_config(config):
          config["is_select"] = False
      else:
          config["is_select"] = True
+
+
+ def time_str2time_digit(time_str):
+     time_format = '%b%d_%H-%M-%S'
+     try:
+         time_digit = datetime.strptime(time_str, time_format)
+     except Exception as e:
+         raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+             of existing output dirpath, like 'Dec03_21-34-40'.") from e
+     return time_digit
+
+
+ def get_target_output_dir(monitor_path, time_start, time_end):
+     check_file_or_directory_path(monitor_path, isdir=True)
+     time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+     time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+     if time_start and time_end and time_start > time_end:
+         raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+     result = {}
+     for dirname in os.listdir(monitor_path):
+         match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+         if not match:
+             continue
+         time_tag = match.group(1)
+         rank = match.group(2)
+         target_time = time_str2time_digit(time_tag)
+         start_ok = time_start is None or target_time >= time_start
+         end_ok = time_end is None or target_time <= time_end
+         if start_ok and end_ok:
+             result[rank] = os.path.join(monitor_path, dirname)
+     return result
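
For reference, the timestamp format these new helpers expect can be checked in isolation (a standalone sketch, independent of msprobe):

```python
from datetime import datetime

# 'Dec03_21-34-40' is the example given in the error message above and
# matches the '%b%d_%H-%M-%S' pattern; strptime defaults the year to 1900,
# which is harmless here because timestamps are only compared to each other.
parsed = datetime.strptime("Dec03_21-34-40", "%b%d_%H-%M-%S")
print(parsed)  # 1900-12-03 21:34:40
```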
msprobe/mindspore/ms_config.py CHANGED
@@ -29,6 +29,7 @@ class TensorConfig(BaseConfig):
          self.check_mode = None
          self.file_format = json_config.get("file_format")
          self.check_config()
+         self._check_summary_mode()
          self._check_config()
  
      def _check_config(self):
@@ -42,12 +43,23 @@ class StatisticsConfig(BaseConfig):
          self.file_format = None
          self.check_mode = None
          self.check_config()
-         self._check_config()
+         self._check_summary_mode()
  
-     def _check_config(self):
-         single_opt = ["statistics", "md5"]
+         self.tensor_list = json_config.get("tensor_list", [])
+         self._check_str_list_config(self.tensor_list, "tensor_list")
+         self.stat_cal_mode = json_config.get("device", "host")
+         self.device_stat_precision_mode = json_config.get("precision", "high")
+         self._check_stat_params()
+
+     def _check_stat_params(self):
+         if self.stat_cal_mode not in ["device", "host"]:
+             raise Exception("Config param [device] is invalid, expected from [\"device\", \"host\"]")
+         if self.device_stat_precision_mode not in ["high", "low"]:
+             raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]")
+
+     def _check_summary_mode(self):
          muti_opt = ["md5", "max", "min", "mean", "l2norm"]
-         if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt:
+         if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE:
              raise Exception("summary_mode is invalid")
          if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
              raise Exception("summary_mode is invalid")
@@ -132,14 +144,3 @@ def parse_task_config(task, json_config):
      if task not in TaskDict:
          raise Exception("task is invalid.")
      return TaskDict.get(task)(task_map)
-
-
- def parse_json_config(json_file_path):
-     if not json_file_path:
-         raise Exception("json file path is None")
-     json_config = load_json(json_file_path)
-     common_config = parse_common_config(json_config)
-     if not common_config.task:
-         common_config.task = Const.STATISTICS
-     task_config = parse_task_config(common_config.task, json_config)
-     return common_config, task_config
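
For orientation, the new keys `StatisticsConfig` reads in this hunk would appear in a statistics-task config roughly as follows (a sketch written as a Python dict; the exact placement inside msprobe's config.json is assumed, and the tensor name is illustrative):

```python
# Only the fields handled in this hunk are shown.
statistics_task_config = {
    "summary_mode": ["max", "min", "mean", "l2norm"],  # list entries must come from muti_opt
    "tensor_list": ["Functional.conv2d"],              # illustrative API name
    "device": "host",                                  # stat_cal_mode: "device" or "host"
    "precision": "high",                               # device_stat_precision_mode: "high" or "low"
}
```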
msprobe/mindspore/overflow_check/overflow_check_tool_factory.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  
+ from msprobe.core.common.log import logger
  from msprobe.mindspore.common.const import Const
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
  from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
@@ -44,6 +45,7 @@ class OverflowCheckToolFactory:
              raise Exception("Valid level is needed.")
          tool = tool.get(config.execution_mode)
          if not tool:
-             raise Exception(f"Overflow check is not supported in {config.execution_mode} mode "
-                             f"when level is {config.level}.")
+             logger.error(f"Overflow check is not supported in {config.execution_mode} mode "
+                          f"when level is {config.level}.")
+             raise ValueError
          return tool(config)
msprobe/mindspore/task_handler_factory.py CHANGED
@@ -29,11 +29,14 @@ class TaskHandlerFactory:
      }
  
      @staticmethod
-     def create(config: DebuggerConfig):
+     def create(config: DebuggerConfig, model=None):
          task = TaskHandlerFactory.tasks.get(config.task)
          if not task:
              raise Exception("Valid task is needed.")
-         handler = task.create(config)
+         if task == DumpToolFactory:
+             handler = task.create(config, model)
+         else:
+             handler = task.create(config)
          if not handler:
              raise Exception("Can not find task handler")
          return handler
msprobe/msprobe.py CHANGED
@@ -22,6 +22,8 @@ from msprobe.core.common.log import logger
  from msprobe.core.compare.utils import _compare_parser
  from msprobe.core.compare.compare_cli import compare_cli
  from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
+ from msprobe.core.config_check.config_check_cli import _config_checking_parser, \
+     _run_config_checking_command
  
  
  def is_module_available(module_name):
@@ -51,6 +53,9 @@ def main():
      graph_service_cmd_parser = subparsers.add_parser('graph')
      op_generate_cmd_parser = subparsers.add_parser('op_generate')
      merge_result_parser = subparsers.add_parser('merge_result')
+     config_checking_parser = subparsers.add_parser('config_check')
+     nan_analyze_parser = subparsers.add_parser('nan_analyze')
+     _config_checking_parser(config_checking_parser)
      _compare_parser(compare_cmd_parser)
      _merge_result_parser(merge_result_parser)
  
@@ -71,6 +76,7 @@ def main():
          from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
          from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
              _run_operator_generate_commond
+         from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze
  
          _run_ut_parser(run_ut_cmd_parser)
          _run_ut_parser(multi_run_ut_cmd_parser)
@@ -80,6 +86,7 @@ def main():
          _run_overflow_check_parser(run_overflow_check_cmd_parser)
          _pt_graph_service_parser(graph_service_cmd_parser)
          _op_generator_parser(op_generate_cmd_parser)
+         _nan_analyze_parser(nan_analyze_parser)
      elif framework_args.framework == Const.MS_FRAMEWORK:
          from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
          from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
@@ -91,6 +98,10 @@ def main():
  
          _ms_graph_service_parser(graph_service_cmd_parser)
  
+         from msprobe.mindspore.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
+             _run_operator_generate_commond
+         _op_generator_parser(op_generate_cmd_parser)
+
      args = parser.parse_args(sys.argv[1:])
      if sys.argv[2] == Const.PT_FRAMEWORK:
          if not is_torch_available:
@@ -118,6 +129,10 @@ def main():
              compare_cli(args)
          elif sys.argv[3] == "merge_result":
              merge_result_cli(args)
+         elif sys.argv[3] == "config_check":
+             _run_config_checking_command(args)
+         elif sys.argv[3] == "nan_analyze":
+             _run_nan_analyze(args)
      else:
          if not is_module_available(Const.MS_FRAMEWORK):
              logger.error("MindSpore does not exist, please install MindSpore library")
@@ -134,9 +149,13 @@ def main():
              mul_api_checker_main(args)
          elif sys.argv[3] == "graph":
              _ms_graph_service_command(args)
+         elif sys.argv[3] == 'op_generate':
+             _run_operator_generate_commond(args)
          elif sys.argv[3] == "code_mapping":
              from msprobe.mindspore.code_mapping.main import code_mapping_main
              code_mapping_main(args)
+         elif sys.argv[3] == "config_check":
+             _run_config_checking_command(args)
  
  
  if __name__ == "__main__":
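
Given the `sys.argv[2]`/`sys.argv[3]` dispatch above, the new subcommands slot in right after the framework selector; assuming the `-f <framework>` convention used in the msprobe docs, an invocation is picked apart roughly like this (a sketch, not the real argparse wiring):

```python
import sys

# Hypothetical command line: msprobe -f pytorch nan_analyze ...
sys.argv = ["msprobe", "-f", "pytorch", "nan_analyze"]

framework, subcommand = sys.argv[2], sys.argv[3]  # the indices main() checks
assert framework == "pytorch" and subcommand == "nan_analyze"
```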
msprobe/nan_analyze/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.