mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,334 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from abc import abstractmethod
+
+ from mindspore import mint, ops
+
+ from msprobe.mindspore.common.log import logger
+ from msprobe.core.common.const import MonitorConst
+
+
+ class OptimizerMon(object):
+     def __init__(self, optim) -> None:
+         self.fp16_to_fp32_param = {}
+         self.optim = optim
+         self.state = {}
+
+     def narrow_from_flatten(self, param, flatten_state):
+         return flatten_state
+
+     def get_state(self, optim):
+         if hasattr(optim, 'chained_optimizers'):
+             for opt in optim.chained_optimizers:
+                 self._get_single_state(opt)
+         else:
+             self._get_single_state(optim)
+
+     def fetch_grad(self, monitor, params2name):
+         if not self.fp16_to_fp32_param:
+             self.map_fp16_to_fp32_param(self.optim)
+
+         grad_dict = {}
+         first_param = True
+         for param, name in params2name.items():
+             if monitor.duplicate_param.get(name, False):
+                 continue
+             if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                 continue
+             grad = param.main_grad if monitor.params_have_main_grad else param.grad
+             element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+             if param.numel() != element_in_cur_partition:
+                 if first_param:
+                     grad = grad.flatten()[-element_in_cur_partition:]
+                 else:  # supposed to be the last one
+                     grad = grad.flatten()[:element_in_cur_partition]
+                 first_param = False
+             if grad is None:
+                 continue
+             tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+             monitor.register_param_call_id("hook_optimizer", tag)
+             grad_dict[tag] = grad
+         return grad_dict
+
+     def map_fp16_to_fp32_param(self, optim):
+         pass
+
+     def fetch_mv(self, monitor, params2name):
+         if not self.fp16_to_fp32_param:
+             self.map_fp16_to_fp32_param(self.optim)
+         if not self.state:
+             self.get_state(self.optim)
+
+         exp_avg_dict = {}
+         exp_avg_sq_dict = {}
+         update_dict = {}
+         ratio_dict = {}
+
+         if not self.state:
+             logger.warning('optimizer state can not be accessed')
+             return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+         for lp_param, name in params2name.items():
+             if lp_param in self.fp16_to_fp32_param:
+                 hp_param = self.fp16_to_fp32_param[lp_param]
+             else:
+                 hp_param = lp_param
+
+             if hp_param in self.state:
+                 state_param = self.state.get(hp_param, {})
+                 exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                 exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
+                 if monitor.mv_distribution:
+                     exp_avg_dict[name] = exp_avg
+                     exp_avg_sq_dict[name] = exp_avg_sq
+                 if monitor.mg_direction:
+                     exp_avg_dict[name] = exp_avg
+                 if monitor.ur_distribution:
+                     if len(self.optim.param_groups) > 1:
+                         logger.info(f"the length of optim.param_groups is {len(self.optim.param_groups)}.")
+                     if 'step' in state_param:
+                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
+                     elif 'step' in self.optim.param_groups[0]:
+                         step = self.optim.param_groups[0]['step']  # AdamW from mindspeed
+                     else:
+                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                         continue
+                     if exp_avg is None or exp_avg_sq is None:
+                         logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
+                         continue
+                     exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step)
+                     exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step)
+                     update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps'])
+                     ratio_dict[name] = exp_avg_hat / mint.sqrt(exp_avg_sq_hat)
+                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+         return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+     def _get_single_state(self, optim):
+         state = {}
+         if hasattr(optim, 'param_to_cpu_states_map'):
+             state = optim.param_to_cpu_states_map
+         elif hasattr(optim, 'state'):
+             state = optim.state
+         elif hasattr(optim, 'optimizer') and hasattr(optim.optimizer, 'state'):
+             state = optim.optimizer.state
+         self.state.update(state)
+
+
+ class MixPrecisionOptimizerMon(OptimizerMon):
+     """
+     Mixed-precision optimizer monitoring class: monitors and manages the optimizer during mixed-precision training.
+     Mixed-precision training speeds up training and reduces memory consumption by lowering the precision of selected computations.
+     """
+     def map_fp16_to_fp32_param(self, optim):
+         for fp16_group, fp32_group in zip(optim.float16_groups, optim.fp32_from_float16_groups):
+             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = fp32_param
+
+
+ class MegatronDistributedOptimizerMon(OptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         if not (hasattr(optim, "model_float16_groups") and
+                 hasattr(optim, "shard_fp32_from_float16_groups")):
+             raise Exception(
+                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
+                 "if not, please check megatron-lm version")
+         for fp16_group, shard_fp32_group in zip(optim.model_float16_groups,
+                                                 optim.shard_fp32_from_float16_groups):
+             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+
+ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         for opt in optim.chained_optimizers:
+             super().map_fp16_to_fp32_param(opt)
+
+
+ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+     def map_fp16_to_fp32_param(self, optim):
+         for opt in optim.chained_optimizers:
+             super().map_fp16_to_fp32_param(opt)
+
+
+ class DeepSpeedZeroOptimizerMon(OptimizerMon):
+     """
+     Base monitor class for DeepSpeed ZeRO optimizer.
+     ZeRO stage 0 no partition
+     ZeRO stage 1 partitions optimizer states across data parallel processes.
+     ZeRO stage 2 additionally partitions gradients.
+     ZeRO stage 3 additionally partitions parameters.
+
+     This class provides monitoring capabilities for ZeRO optimizers by:
+     - Handling gradient collection for different ZeRO stages
+     - Managing optimizer state access for monitoring
+     """
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = ''
+         self.bit16_groups = []
+         self.fp32_flat_groups = []
+         self.param2group = ()
+         self.param2index = []
+         self.group_offset = {}
+
+     @abstractmethod
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         raise NotImplementedError
+
+     def param_not_in_partition(self, lp_param, group_idx):
+         param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+         hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+         return hp_address is None
+
+     def get_position(self, lp_param, group_idx):
+         param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+         hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+         return hp_address.start, hp_address.numel
+
+     def get_group_index(self):
+         param2group = {}
+         for group_idx, bit16_group in enumerate(self.bit16_groups):
+             for param in bit16_group:
+                 param2group[param] = group_idx
+         return param2group
+
+     def get_param_index(self, lp_param, group_idx):
+         if not self.param2index:
+             for group in self.bit16_groups:
+                 param2index = {}
+                 for index, param in enumerate(group):
+                     param2index[param] = index
+                 self.param2index.append(param2index)
+
+         return self.param2index[group_idx][lp_param]
+
+     def narrow_from_flatten(self, param, flatten_state):
+         if flatten_state is None:
+             return flatten_state
+         group_idx = self.param2group[param]
+         if self.param_not_in_partition(param, group_idx):
+             return None
+         start, numel = self.get_position(param, group_idx)
+         return flatten_state.narrow(0, start, numel)
+
+     def map_fp16_to_fp32_param(self, optim):
+         for group_idx, group in enumerate(self.bit16_groups):
+             for param in group:
+                 self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+     def fetch_grad(self, monitor, params2name):
+         grad_dict = {}
+         for lp_param, name in params2name.items():
+             group_idx = self.param2group[lp_param]
+             param_id = self.get_param_index(lp_param, group_idx)
+             if self.param_not_in_partition(lp_param, group_idx):
+                 continue
+             if self.stage == '1or2':
+                 param_id = param_id - self.group_offset[group_idx] - 1
+             grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+             tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+             monitor.register_param_call_id("hook_optimizer", tag)
+             grad_dict[tag] = grad
+
+         return grad_dict
+
+
+ class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '0'
+         self.bit16_groups = optim.bf16_groups
+         self.fp32_flat_groups = optim.fp32_groups_flat_partition
+         self.param2group = self.get_group_index()
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         return self.optim.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+ class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '1or2'
+         self.bit16_groups = optim.bit16_groups
+         self.fp32_flat_groups = optim.single_partition_of_fp32_groups
+         self.param2group = self.get_group_index()
+         self.group_offset = {}
+         self.get_group_offset()
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         if getattr(self.optim, "cpu_offload", False):
+             grads = self.optim.single_partition_of_fp32_groups[group_idx].grad
+             start, numel = self.get_position(lp_param, group_idx)
+             grad = grads.narrow(0, start, numel)
+         else:
+             grad = self.optim.averaged_gradients[group_idx][param_id]
+         return grad
+
+     def get_group_offset(self):
+         for group_idx, group in enumerate(self.bit16_groups):
+             self.group_offset[group_idx] = -1
+             for lp_param in group:
+                 if self.param_not_in_partition(lp_param, group_idx):
+                     self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
+                 else:
+                     break
+
+
+ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+     def __init__(self, optim):
+         super().__init__(optim)
+         self.stage = '3'
+         self.bit16_groups = optim.fp16_groups
+         self.fp32_flat_groups = optim.fp32_partitioned_groups_flat
+         self.param2group = self.get_group_index()
+
+     def param_not_in_partition(self, lp_param, group_idx):
+         """Each param is partitioned across all zero ranks"""
+         return False
+
+     def get_position(self, lp_param, group_idx):
+         param_id = self.optim.get_param_id(lp_param)
+         return self.optim.grad_position[param_id][1:]
+
+     def get_grad_for_param(self, lp_param, group_idx, param_id):
+         return self.optim.averaged_gradients[group_idx][param_id]
+
+
+ class OptimizerMonFactory:
+     _optimizer_mon_map = {
+         "FP32Optimizer": OptimizerMon,
+         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+         "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
+         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+         "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+         "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
+         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
+         "Adam": OptimizerMon
+     }
+
+     @staticmethod
+     def create_optimizer_mon(optimizer):
+         # auto replace opt_ty
+         optimizer_class = optimizer.__class__.__name__
+         if optimizer_class == "ChainedOptimizer":
+             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+         logger.info(f'The optimizer type is {optimizer_class}')
+
+         optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+         return optimizer_mon_class(optimizer)
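The new module above appears to be msprobe/mindspore/monitor/optimizer_collect.py (entry 139 in the file list), judging by the +334 line count and the MindSpore imports. As a reading aid only, here is a minimal, hypothetical sketch of how the factory and the fetch_* methods are meant to be driven; the import path follows that file-name assumption, and `monitor` / `params2name` stand in for the monitor object and parameter-name mapping that module_hook.py would supply.

```python
# Hypothetical usage sketch, not part of the diff. Assumes the hunk above lands in
# msprobe/mindspore/monitor/optimizer_collect.py; `monitor` must expose the attributes
# referenced by fetch_grad/fetch_mv (duplicate_param, name2tag, params_have_main_grad,
# register_param_call_id, mv_distribution, ...), and `params2name` maps Parameter -> name.
from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory


def collect_optimizer_metrics(monitor, optimizer, params2name):
    # Resolve the OptimizerMon subclass from the wrapped optimizer's class name
    # ("DistributedOptimizer", "DeepSpeedZeroOptimizer", ...; unknown types fall back to OptimizerMon).
    optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)

    # Per-parameter gradients, keyed by the tag registered under MonitorConst.POST_GRAD.
    grad_dict = optimizer_mon.fetch_grad(monitor, params2name)

    # Adam-style first/second moments plus update and ratio statistics.
    exp_avg, exp_avg_sq, update, ratio = optimizer_mon.fetch_mv(monitor, params2name)
    return grad_dict, exp_avg, exp_avg_sq, update, ratio
```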
@@ -24,18 +24,24 @@ from msprobe.core.common.log import logger
  from msprobe.core.common.file_utils import check_file_or_directory_path


- def get_single_metrics(op_list, tag, tensor, output=None):
+ def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
      if output is None:
          output = {}
      if tag not in output:
          output[tag] = {}
      for op in op_list:
          func = FUNC_MAP.get(op)
-         statistic = func(tensor)
+         if op == "zeros":
+             statistic = func(tensor, eps)
+         else:
+             statistic = func(tensor)
          if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
              statistic = float(statistic)
              statistic = Tensor(statistic)
-         output[tag][op] = statistic.astype(mstype.float32)
+         if isinstance(statistic, Tensor):
+             output[tag][op] = statistic.astype(mstype.float32)
+         else:
+             output[tag][op] = statistic


  def get_metrics(op_list, tag2tensor, eps, output=None):
@@ -44,7 +50,7 @@ def get_metrics(op_list, tag2tensor, eps, output=None):
      for tag, tensor in tag2tensor.items():
          if tag not in output:
              output[tag] = {}
-         get_single_metrics(op_list, tag, tensor, output)
+         get_single_metrics(op_list, tag, tensor, eps, output)
      return output


@@ -91,6 +97,11 @@ def validate_ops(ops):
          default_op = MonitorConst.OP_LIST[0]
          valid_ops.append(default_op)
          logger.info(f"There is no valid ops, default op {default_op} is used")
+     # add the default shape and dtype ops
+     if "shape" not in valid_ops:
+         valid_ops.append("shape")
+     if "dtype" not in valid_ops:
+         valid_ops.append("dtype")
      return valid_ops


@@ -171,7 +182,7 @@ def validate_alert(alert):
          args = rule.get("args")
          if args and isinstance(args, dict):
              threshold = args.get("threshold")
-             if not isinstance(threshold, float) or threshold < 0:
+             if not isinstance(threshold, (float, int)) or threshold < 0:
                  raise TypeError('threshold must be float and not less than 0')
      dump = alert.get('dump')
      if dump and not isinstance(dump, bool):
@@ -217,6 +228,13 @@ def validate_dynamic_on(dynamic_on):
          raise TypeError('dynamic_on should be a bool')


+ def validate_monitor_mbs_grad(monitor_mbs_grad):
+     if not isinstance(monitor_mbs_grad, bool):
+         logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
+         return False
+     return monitor_mbs_grad
+
+
  def validate_config(config):
      config['ops'] = validate_ops(config.get('ops', []))

@@ -266,6 +284,8 @@ def validate_config(config):
      collect_times = config.get('collect_times', int(1e8))
      validate_collect_times(collect_times)

+     config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
+
      dynamic_on = config.get('dynamic_on', False)
      validate_dynamic_on(dynamic_on)

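These hunks appear to belong to msprobe/mindspore/monitor/utils.py (+25 -5): get_metrics now threads an eps argument into get_single_metrics (consumed by the "zeros" op), "shape" and "dtype" are always appended to the validated op list, integer alert thresholds are accepted, and validate_config picks up the new monitor_mbs_grad switch. A small, illustrative call under those assumptions; the op names other than "zeros" and the tag string are invented for the example.

```python
# Illustrative only; assumes the hunks above come from msprobe/mindspore/monitor/utils.py.
from mindspore import Tensor
from mindspore import dtype as mstype

from msprobe.mindspore.monitor.utils import get_metrics

# Tag naming here is made up; real tags are produced by the monitor's name2tag mapping.
tag2tensor = {"demo_param/post_grad": Tensor([0.0, 1.0, -2.0], mstype.float32)}

# eps is now forwarded to get_single_metrics and, from there, to FUNC_MAP["zeros"].
metrics = get_metrics(["min", "max", "zeros"], tag2tensor, eps=1e-8)
# metrics["demo_param/post_grad"]["zeros"] holds whatever FUNC_MAP["zeros"](tensor, eps)
# computes (presumably the ratio of near-zero elements); non-Tensor statistics such as
# "shape"/"dtype" are now stored as-is instead of being cast to float32.
```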
@@ -29,6 +29,7 @@ class TensorConfig(BaseConfig):
          self.check_mode = None
          self.file_format = json_config.get("file_format")
          self.check_config()
+         self._check_summary_mode()
          self._check_config()

      def _check_config(self):
@@ -42,12 +43,23 @@ class StatisticsConfig(BaseConfig):
          self.file_format = None
          self.check_mode = None
          self.check_config()
-         self._check_config()
+         self._check_summary_mode()

-     def _check_config(self):
-         single_opt = ["statistics", "md5"]
+         self.tensor_list = json_config.get("tensor_list", [])
+         self._check_str_list_config(self.tensor_list, "tensor_list")
+         self.stat_cal_mode = json_config.get("device", "host")
+         self.device_stat_precision_mode = json_config.get("precision", "high")
+         self._check_stat_params()
+
+     def _check_stat_params(self):
+         if self.stat_cal_mode not in ["device", "host"]:
+             raise Exception("Config param [device] is invalid, expected from [\"device\", \"host\"]")
+         if self.device_stat_precision_mode not in ["high", "low"]:
+             raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]")
+
+     def _check_summary_mode(self):
          muti_opt = ["md5", "max", "min", "mean", "l2norm"]
-         if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt:
+         if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE:
              raise Exception("summary_mode is invalid")
          if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
              raise Exception("summary_mode is invalid")
@@ -132,14 +144,3 @@ def parse_task_config(task, json_config):
      if task not in TaskDict:
          raise Exception("task is invalid.")
      return TaskDict.get(task)(task_map)
-
-
- def parse_json_config(json_file_path):
-     if not json_file_path:
-         raise Exception("json file path is None")
-     json_config = load_json(json_file_path)
-     common_config = parse_common_config(json_config)
-     if not common_config.task:
-         common_config.task = Const.STATISTICS
-     task_config = parse_task_config(common_config.task, json_config)
-     return common_config, task_config
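These hunks look like msprobe/mindspore/ms_config.py (+16 -15): StatisticsConfig now also reads tensor_list, device and precision from the task section, summary_mode strings are checked against Const.SUMMARY_MODE, and the module-level parse_json_config helper is removed. Below is a hypothetical statistics-task config under those assumptions; the key names tensor_list/device/precision come from the diff, while the surrounding layout follows the usual msprobe config.json structure and may be incomplete.

```python
# Hypothetical config sketch; only "tensor_list", "device" and "precision" are taken
# from the diff above, everything else follows the documented config.json layout.
statistics_task_config = {
    "task": "statistics",
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": ["Functional.matmul"],  # presumably: APIs whose tensors get extra dumping
        "device": "host",        # maps to StatisticsConfig.stat_cal_mode ("host" or "device")
        "precision": "high",     # maps to StatisticsConfig.device_stat_precision_mode ("high" or "low")
        "summary_mode": "statistics",
    },
}
```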
@@ -29,11 +29,14 @@ class TaskHandlerFactory:
      }

      @staticmethod
-     def create(config: DebuggerConfig):
+     def create(config: DebuggerConfig, model=None):
          task = TaskHandlerFactory.tasks.get(config.task)
          if not task:
              raise Exception("Valid task is needed.")
-         handler = task.create(config)
+         if task == DumpToolFactory:
+             handler = task.create(config, model)
+         else:
+             handler = task.create(config)
          if not handler:
              raise Exception("Can not find task handler")
          return handler
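The hunk above matches msprobe/mindspore/task_handler_factory.py (+5 -2): create() gains an optional model argument and forwards it only when the resolved factory is DumpToolFactory. A minimal sketch of the new call shape, assuming a DebuggerConfig and a mindspore.nn.Cell are already in hand:

```python
# Sketch only; `config` is a msprobe DebuggerConfig and `net` a mindspore.nn.Cell.
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory


def build_task_handler(config, net=None):
    # Existing callers can keep calling create(config); passing the network lets
    # dump tasks (DumpToolFactory) receive the model, while other tasks ignore it.
    return TaskHandlerFactory.create(config, model=net)
```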
msprobe/msprobe.py CHANGED
@@ -22,6 +22,8 @@ from msprobe.core.common.log import logger
  from msprobe.core.compare.utils import _compare_parser
  from msprobe.core.compare.compare_cli import compare_cli
  from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
+ from msprobe.core.config_check.config_check_cli import _config_checking_parser, \
+     _run_config_checking_command


  def is_module_available(module_name):
@@ -51,6 +53,9 @@ def main():
      graph_service_cmd_parser = subparsers.add_parser('graph')
      op_generate_cmd_parser = subparsers.add_parser('op_generate')
      merge_result_parser = subparsers.add_parser('merge_result')
+     config_checking_parser = subparsers.add_parser('config_check')
+     nan_analyze_parser = subparsers.add_parser('nan_analyze')
+     _config_checking_parser(config_checking_parser)
      _compare_parser(compare_cmd_parser)
      _merge_result_parser(merge_result_parser)

@@ -71,6 +76,7 @@ def main():
          from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
          from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
              _run_operator_generate_commond
+         from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze

          _run_ut_parser(run_ut_cmd_parser)
          _run_ut_parser(multi_run_ut_cmd_parser)
@@ -80,6 +86,7 @@ def main():
          _run_overflow_check_parser(run_overflow_check_cmd_parser)
          _pt_graph_service_parser(graph_service_cmd_parser)
          _op_generator_parser(op_generate_cmd_parser)
+         _nan_analyze_parser(nan_analyze_parser)
      elif framework_args.framework == Const.MS_FRAMEWORK:
          from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
          from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
@@ -91,6 +98,10 @@ def main():

          _ms_graph_service_parser(graph_service_cmd_parser)

+         from msprobe.mindspore.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
+             _run_operator_generate_commond
+         _op_generator_parser(op_generate_cmd_parser)
+
      args = parser.parse_args(sys.argv[1:])
      if sys.argv[2] == Const.PT_FRAMEWORK:
          if not is_torch_available:
@@ -118,6 +129,10 @@ def main():
              compare_cli(args)
          elif sys.argv[3] == "merge_result":
              merge_result_cli(args)
+         elif sys.argv[3] == "config_check":
+             _run_config_checking_command(args)
+         elif sys.argv[3] == "nan_analyze":
+             _run_nan_analyze(args)
      else:
          if not is_module_available(Const.MS_FRAMEWORK):
              logger.error("MindSpore does not exist, please install MindSpore library")
@@ -134,9 +149,13 @@ def main():
              mul_api_checker_main(args)
          elif sys.argv[3] == "graph":
              _ms_graph_service_command(args)
+         elif sys.argv[3] == 'op_generate':
+             _run_operator_generate_commond(args)
          elif sys.argv[3] == "code_mapping":
              from msprobe.mindspore.code_mapping.main import code_mapping_main
              code_mapping_main(args)
+         elif sys.argv[3] == "config_check":
+             _run_config_checking_command(args)


  if __name__ == "__main__":
@@ -0,0 +1,14 @@
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.