mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import abc
18
+ from mindspore import Tensor
19
+
20
+ from msprobe.core.common.log import logger
21
+
22
+
23
# Registry of every ConfigValidator implementation, keyed by class name.
# It is filled at import time by the @register_config_validator decorator below.
config_validator_registry = {}


def register_config_validator(cls):
    """Class decorator that records ``cls`` in ``config_validator_registry``.

    The class is stored under its ``__name__`` and handed back unchanged, so
    the decorator can sit directly on a class definition.
    """
    validator_name = cls.__name__
    config_validator_registry[validator_name] = cls
    return cls
31
+
32
+
33
class ConfigValidator(metaclass=abc.ABCMeta):
    """Abstract interface for validators of one config spec string.

    A subclass first reports whether it understands a spec
    (``check_pattern_match``) and, if so, checks runtime data against it
    (``validate``).
    """

    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        """Return a truthy match object when ``config_spec`` is handled by this validator."""

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Check ``actual_data`` against the spec matched by ``check_pattern_match``.

        Implementations signal a mismatch by raising ``ValueError``; the return
        value is handed back by ``validate_config_spec``.
        """
41
+
42
+
43
@register_config_validator
class TensorValidator(ConfigValidator):
    """Validator for the ``"tensor"`` spec: the checked value must be a ``Tensor``."""

    # Compiled once per class instead of on every check_pattern_match call.
    _SPEC_PATTERN = re.compile(r"tensor")

    def check_pattern_match(self, config_spec: str):
        # re.match anchors only at the start, so any spec beginning with "tensor" matches.
        return self._SPEC_PATTERN.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Raise ``ValueError`` unless ``actual_data`` is a ``Tensor``; otherwise return None."""
        if isinstance(actual_data, Tensor):
            return None
        raise ValueError(
            f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
53
+
54
+
55
@register_config_validator
class TupleValidator(ConfigValidator):
    """Validator for ``"tuple[x]"`` and ``"tuple[x]:y"`` specs.

    ``x`` is the required tuple length and ``y`` (default 0) is the index of
    the element to focus on; ``validate`` returns that index.
    """

    # Compiled once per class instead of on every check_pattern_match call.
    _SPEC_PATTERN = re.compile(r"tuple\[(\d+)\]:?(\d+)?")

    def check_pattern_match(self, config_spec: str):
        return self._SPEC_PATTERN.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Check ``actual_data`` against the matched tuple spec.

        :param actual_data: value whose structure is checked.
        :param module_name: module name, used in error messages.
        :param data_type: label used in error messages.
        :param pattern_match: match object returned by ``check_pattern_match``.
        :return: the focused element index ``y``.
        :raises ValueError: if ``y`` is not in ``[0, x)``, ``actual_data`` is not
            a tuple, or its length differs from ``x``.
        """
        length_text, index_text = pattern_match.groups()
        length = int(length_text)
        # The ":y" part is optional; without it the first element is focused.
        index = 0 if index_text is None else int(index_text)

        if not 0 <= index < length:
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format "
                f"'tuple[x]:y'. y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != length:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
                f"actual is {len(actual_data)}, please check.")
        return index
79
+
80
+
81
def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    """Validate ``actual_data`` with the first registered validator whose pattern matches ``config_spec``.

    :param config_spec: spec string from config.json, e.g. ``"tensor"`` or ``"tuple[2]:0"``.
    :param actual_data: runtime value to check.
    :param module_name: module name, used in log messages.
    :param data_type: label used in log messages.
    :return: the matching validator's ``validate`` result (the focused index for
        tuple specs), or None when validation failed or no validator matched.
        A ``ValueError`` raised by the validator is logged as a warning instead
        of propagating.
    """
    focused_col = None
    for validator_cls in config_validator_registry.values():
        config_validator = validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if not pattern_match:
            continue
        try:
            focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
        except ValueError as e:
            logger.warning(f"config spec validate failed: {str(e)}")
        return focused_col
    # Raw f-string: the message deliberately shows regex text, and "\[", "\d", "\]"
    # are invalid escape sequences in an ordinary string literal (SyntaxWarning on 3.12+).
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   rf"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
    return focused_col
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from mindspore import dtype as mstype, Tensor
17
+
18
+ from msprobe.mindspore.monitor.features import FUNC_MAP
19
+ from msprobe.core.common.const import MonitorConst
20
+ from msprobe.core.common.utils import is_int
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
def get_single_metrics(op_list, tag, tensor, output=None):
    """Compute each statistic in ``op_list`` for ``tensor`` and store it under ``output[tag]``.

    :param op_list: statistic names; each is looked up in ``FUNC_MAP``.
    :param tag: key under which the results are stored.
    :param tensor: value to summarize (presumably a MindSpore Tensor — confirm with callers).
    :param output: dict filled in place; a new dict is created when omitted.
    :return: ``output``, mapping ``tag`` -> {op: float32 Tensor}. Previously the
        freshly created dict was discarded when ``output`` was None.
    """
    if output is None:
        output = {}
    if tag not in output:
        output[tag] = {}
    for op in op_list:
        # NOTE(review): an op missing from FUNC_MAP yields None here and fails on the call below.
        func = FUNC_MAP.get(op)
        statistic = func(tensor)
        # bfloat16 results are routed through a Python float before being wrapped again.
        if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
            statistic = float(statistic)
        statistic = Tensor(statistic)
        output[tag][op] = statistic.astype(mstype.float32)
    return output
36
+
37
+
38
def get_metrics(op_list, tag2tensor, eps, output=None):
    """Compute the ``op_list`` statistics for every tensor in ``tag2tensor``.

    Results are written into ``output`` (a new dict when omitted), keyed by tag,
    and that dict is returned. ``eps`` is not used by this function.
    """
    result = {} if output is None else output
    for tag_name, tensor in tag2tensor.items():
        result.setdefault(tag_name, {})
        get_single_metrics(op_list, tag_name, tensor, result)
    return result
46
+
47
+
48
def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
    """Build a summary-writer tag: ``name/tag``, or ``name/rank{rank}/tag`` when ``rank`` is not None."""
    rank_segment = "" if rank is None else f"rank{rank}/"
    return f"{module_or_param_name}/{rank_segment}{tag}"
53
+
54
+
55
def step_accumulates_one(context, micro_batch_number):
    """Advance ``context`` by one micro step, rolling over into the next step.

    :param context: ModuleHookContext; its ``micro_step`` and ``step`` counters are updated in place.
    :param micro_batch_number: mbs of training model (micro steps per step).
    :return: None
    """
    next_micro_step = context.micro_step + 1
    if next_micro_step == micro_batch_number:
        # A full step has completed: reset the micro counter.
        context.micro_step = 0
        context.step += 1
    else:
        context.micro_step = next_micro_step
65
+
66
+
67
def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
    """Decide whether ``step`` should be skipped.

    A step is collected only if it is at or after ``start_step``, falls on the
    ``step_interval`` grid counted from ``start_step``, and fewer than
    ``collect_times`` collections have happened so far.

    :param step: current training step, int
    :param start_step: first step eligible for collection, int
    :param step_interval: distance between collected steps, int
    :param has_collect_times: number of collections already done
    :param collect_times: maximum number of collections
    :return: whether skip or not, bool
    """
    if step < start_step:
        return True
    if (step - start_step) % step_interval != 0:
        return True
    return has_collect_times >= collect_times
76
+
77
+
78
def validate_ops(ops):
    """Return the supported entries of ``ops`` in order, falling back to the first supported op.

    Unsupported entries are logged as warnings and dropped. If none remain,
    ``[MonitorConst.OP_LIST[0]]`` is returned.

    :raises TypeError: if ``ops`` is not a list.
    """
    if not isinstance(ops, list):
        raise TypeError("ops should be a list")
    supported_ops = MonitorConst.OP_LIST
    valid_ops = []
    for op in ops:
        if op in supported_ops:
            valid_ops.append(op)
        else:
            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
    if valid_ops:
        return valid_ops
    default_op = supported_ops[0]
    logger.info(f"There is no valid ops, default op {default_op} is used")
    return [default_op]
92
+
93
+
94
def validate_ranks(ranks):
    """Ensure ``ranks`` (the ``module_ranks`` entry) is a list of str.

    :raises TypeError: if ``ranks`` is not a list or contains a non-str element.
    """
    if not isinstance(ranks, list):
        raise TypeError("module_ranks should be a list")
    for rank_entry in ranks:
        if isinstance(rank_entry, str):
            continue
        raise TypeError(f"element in module_ranks should be a str, get {type(rank_entry)}")
100
+
101
+
102
def validate_targets(targets):
    """Ensure ``targets`` maps module names (str) to field specs (dict).

    :raises TypeError: if ``targets`` is not a dict, or any key/value has the wrong type.
    """
    if not isinstance(targets, dict):
        raise TypeError('targets in config.json should be a dict')
    for module_name, field in targets.items():
        key_ok = isinstance(module_name, str)
        if not key_ok:
            raise TypeError('key of targets should be module_name[str] in config.json')
        value_ok = isinstance(field, dict)
        if not value_ok:
            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
110
+
111
+
112
def validate_print_struct(print_struct):
    """Ensure the ``print_struct`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(print_struct, bool):
        return
    raise TypeError("print_struct should be a bool")
115
+
116
+
117
def validate_ur_distribution(ur_distribution):
    """Ensure the ``ur_distribution`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(ur_distribution, bool):
        return
    raise TypeError('ur_distribution should be a bool')
120
+
121
+
122
def validate_xy_distribution(xy_distribution):
    """Ensure the ``xy_distribution`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(xy_distribution, bool):
        return
    raise TypeError('xy_distribution should be a bool')
125
+
126
+
127
def validate_wg_distribution(wg_distribution):
    """Ensure the ``wg_distribution`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(wg_distribution, bool):
        return
    raise TypeError('wg_distribution should be a bool')
130
+
131
+
132
def validate_mg_distribution(mg_distribution):
    """Ensure the ``mg_distribution`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(mg_distribution, bool):
        return
    raise TypeError('mg_distribution should be a bool')
135
+
136
+
137
def validate_param_distribution(param_distribution):
    """Ensure the ``param_distribution`` option is a bool.

    :raises TypeError: if it is not.
    """
    if isinstance(param_distribution, bool):
        return
    raise TypeError('param_distribution should be a bool')
140
+
141
+
142
def validate_cc_distribution(cc_distribution):
    """Check the ``cc_distribution`` section: a dict whose keys are known and correctly typed.

    Allowed keys (all optional): ``enable`` (bool), ``cc_codeline`` (list),
    ``cc_pre_hook`` (bool), ``cc_log_only`` (bool).

    :raises TypeError: for a non-dict section, an unknown key, or a value of the wrong type.
    """
    if not isinstance(cc_distribution, dict):
        raise TypeError('cc_distribution should be a dictionary')
    expected_types = {
        'enable': bool,
        'cc_codeline': list,
        'cc_pre_hook': bool,
        'cc_log_only': bool,
    }
    for key, value in cc_distribution.items():
        expected_type = expected_types.get(key)
        # No registered type is ever None, so None means "unknown key".
        if expected_type is None:
            raise TypeError(f'{key} of cc_distribution is not supported.')
        if not isinstance(value, expected_type):
            raise TypeError(f'cc_distribution {key} should be a {expected_type.__name__}')
157
+
158
+
159
def validate_alert(alert):
    """Validate the ``alert`` section of the monitor config.

    Checks that ``alert`` is a dict. When ``alert["rules"]`` is a non-empty
    list, every rule must be a dict whose ``rule_name`` (if set) is in
    ``MonitorConst.RULE_NAME`` and whose ``args["threshold"]`` (when ``args`` is
    a non-empty dict) is a non-negative float. A truthy ``alert["dump"]`` must
    be a bool.

    :raises TypeError: on any violation.
    """
    if not isinstance(alert, dict):
        raise TypeError('alert should be a dictionary')
    rules = alert.get('rules')
    if rules and isinstance(rules, list):
        for rule in rules:
            # Without this guard a non-dict rule crashed with AttributeError on rule.get().
            if not isinstance(rule, dict):
                raise TypeError(f'each rule in alert should be a dictionary, get {type(rule)}')
            rule_name = rule.get("rule_name")
            if rule_name and rule_name not in MonitorConst.RULE_NAME:
                raise TypeError(f"{rule_name} is not supported")
            args = rule.get("args")
            if args and isinstance(args, dict):
                threshold = args.get("threshold")
                if not isinstance(threshold, float) or threshold < 0:
                    raise TypeError('threshold must be float and not less than 0')
    dump = alert.get('dump')
    if dump and not isinstance(dump, bool):
        raise TypeError('dump must be bool.')
176
+
177
+
178
def validate_step_count_per_record(step_count_per_record):
    """Ensure ``step_count_per_record`` is an int in the closed range [1, 1e6].

    :raises TypeError: if ``is_int`` rejects it.
    :raises ValueError: if it is outside the range.
    """
    if not is_int(step_count_per_record):
        raise TypeError('step_count_per_record must be int.')
    if not 1 <= step_count_per_record <= 1e6:
        if step_count_per_record < 1:
            message = "step_count_per_record must greater than 0"
        else:
            message = "step_count_per_record must smaller than 1e6"
        raise ValueError(message)
185
+
186
+
187
def validate_start_step(start_step):
    """Ensure ``start_step`` is an int in the closed range [0, 1e8].

    :raises TypeError: if ``is_int`` rejects it.
    :raises ValueError: if it is negative or larger than 1e8.
    """
    if not is_int(start_step):
        raise TypeError('start_step must be int.')
    if start_step < 0:
        # 0 is accepted, so the old "must greater than 0" message was wrong.
        raise ValueError("start_step must not be less than 0")
    if start_step > 1e8:
        raise ValueError("start_step must not be greater than 1e8")
194
+
195
+
196
def validate_step_interval(step_interval):
    """Ensure ``step_interval`` is an int in the closed range [1, 1e8].

    :raises TypeError: if ``is_int`` rejects it.
    :raises ValueError: if it is below 1 or larger than 1e8.
    """
    if not is_int(step_interval):
        raise TypeError('step_interval must be int.')
    if step_interval < 1:
        # 1 is accepted, so the old "must greater than 1" message was wrong.
        raise ValueError("step_interval must be greater than 0")
    if step_interval > 1e8:
        raise ValueError("step_interval must not be greater than 1e8")
203
+
204
+
205
def validate_collect_times(collect_times):
    """Ensure ``collect_times`` is an int of at least 1.

    :raises TypeError: if ``is_int`` rejects it.
    :raises ValueError: if it is below 1.
    """
    if not is_int(collect_times):
        raise TypeError('collect_times must be int.')
    if collect_times < 1:
        # 1 is accepted, so the old "must greater than 1" message was wrong.
        raise ValueError("collect_times must be greater than 0")
210
+
211
+
212
def validate_config(config):
    """Validate the monitor config and fill in derived entries in place.

    Side effects on ``config``:
      * ``config["ops"]`` is replaced by the filtered list from ``validate_ops``;
      * when ``targets`` is non-empty, ``config["is_select"]`` is set to True;
      * otherwise ``config["targets"]`` becomes ``{"": {}}``, ``config["is_select"]``
        becomes False, and ``config["all_xy"]`` is set to True if ``xy_distribution`` is on.

    :raises TypeError: from ``eps`` checking or the individual ``validate_*`` helpers.
    :raises ValueError: from the range checks in the step/collect validators.
    """
    config['ops'] = validate_ops(config.get('ops', []))

    eps = config.get('eps', 1e-8)
    if not isinstance(eps, float):
        raise TypeError("eps should be a float")

    ranks = config.get("module_ranks", [])
    validate_ranks(ranks)

    targets = config.get("targets", {})
    validate_targets(targets)

    print_struct = config.get('print_struct', False)
    validate_print_struct(print_struct)

    ur_distribution = config.get('ur_distribution', False)
    validate_ur_distribution(ur_distribution)

    xy_distribution = config.get('xy_distribution', False)
    validate_xy_distribution(xy_distribution)

    wg_distribution = config.get('wg_distribution', False)
    validate_wg_distribution(wg_distribution)

    mg_distribution = config.get('mg_distribution', False)
    validate_mg_distribution(mg_distribution)

    param_distribution = config.get('param_distribution', False)
    validate_param_distribution(param_distribution)

    cc_distribution = config.get('cc_distribution', {})
    validate_cc_distribution(cc_distribution)

    alert = config.get('alert', {})
    validate_alert(alert)

    step_count_per_record = config.get('step_count_per_record', 1)
    validate_step_count_per_record(step_count_per_record)

    start_step = config.get('start_step', 0)
    validate_start_step(start_step)

    step_interval = config.get('step_interval', 1)
    validate_step_interval(step_interval)

    # The default must be an int: the float literal 1e8 would likely be rejected by
    # is_int when collect_times is omitted. NOTE(review): confirm is_int semantics.
    collect_times = config.get('collect_times', int(1e8))
    validate_collect_times(collect_times)

    if not targets:
        if xy_distribution:
            config["all_xy"] = True
        config["targets"] = {"": {}}
        config["is_select"] = False
    else:
        config["is_select"] = True
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -45,7 +45,11 @@ class StatisticsConfig(BaseConfig):
45
45
  self._check_config()
46
46
 
47
47
  def _check_config(self):
48
- if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
48
+ single_opt = ["statistics", "md5"]
49
+ muti_opt = ["md5", "max", "min", "mean", "l2norm"]
50
+ if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt:
51
+ raise Exception("summary_mode is invalid")
52
+ if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
49
53
  raise Exception("summary_mode is invalid")
50
54
 
51
55
 
@@ -102,12 +106,18 @@ class GradProbeConfig(BaseConfig):
102
106
  check_numeral_list_ascend(self.bounds)
103
107
 
104
108
 
109
class StructureConfig(BaseConfig):
    """Config for the structure task; it adds nothing to ``BaseConfig`` and inherits its constructor."""
112
+
113
+
105
114
  TaskDict = {
106
115
  Const.TENSOR: TensorConfig,
107
116
  Const.STATISTICS: StatisticsConfig,
108
117
  Const.OVERFLOW_CHECK: OverflowCheckConfig,
109
118
  Const.FREE_BENCHMARK: FreeBenchmarkConfig,
110
- Const.GRAD_PROBE: GradProbeConfig
119
+ Const.GRAD_PROBE: GradProbeConfig,
120
+ Const.STRUCTURE: StructureConfig
111
121
  }
112
122
 
113
123
 
@@ -46,6 +46,13 @@ class KernelGraphOverflowCheck:
46
46
  self.dump_json["common_dump_settings"]["op_debug_mode"] = 2
47
47
 
48
48
  def handle(self):
49
+ try:
50
+ from msprobe.lib import _msprobe_c
51
+ return
52
+ except ImportError:
53
+ # 如果没有_msprobe_ce_c走MindSpore老流程
54
+ logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
55
+
49
56
  if os.getenv("GRAPH_OP_RUN") == "1":
50
57
  raise Exception("Must run in graph mode, not kbk mode")
51
58
  json_path = self.dump_json["common_dump_settings"]["path"]