mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import abc
18
+ from mindspore import Tensor
19
+
20
+ from msprobe.core.common.log import logger
21
+
22
+
23
# Registry of every ConfigValidator implementation, keyed by class name.
config_validator_registry = {}


def register_config_validator(cls):
    """Class decorator: record a ConfigValidator subclass in the registry and return it unchanged."""
    config_validator_registry[cls.__name__] = cls
    return cls
31
+
32
+
33
class ConfigValidator(metaclass=abc.ABCMeta):
    """Abstract interface for validating one kind of config spec string.

    Implementations are registered via @register_config_validator and tried
    in turn until one's check_pattern_match succeeds.
    """

    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        """Return a regex match object when this validator handles config_spec, else None."""
        pass

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Raise ValueError when actual_data does not conform to the matched spec."""
        pass
41
+
42
+
43
@register_config_validator
class TensorValidator(ConfigValidator):
    """Handles the 'tensor' spec: the actual data must be a mindspore Tensor."""

    def check_pattern_match(self, config_spec: str):
        # Matches any spec beginning with 'tensor'.
        return re.compile(r"tensor").match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        if isinstance(actual_data, Tensor):
            return
        raise ValueError(
            f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
53
+
54
+
55
@register_config_validator
class TupleValidator(ConfigValidator):
    """Handles the 'tuple[x]:y' spec: a tuple of length x, focused on element y.

    The ':y' part is optional and defaults to 0. validate() returns the
    focused index on success.
    """

    def check_pattern_match(self, config_spec: str):
        # Group 1 = declared tuple length, optional group 2 = focused index.
        return re.compile(r"tuple\[(\d+)\]:?(\d+)?").match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        size, pos = pattern_match.groups()
        size = int(size)
        pos = 0 if pos is None else int(pos)

        if pos < 0 or pos >= size:
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'."
                f"y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != size:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {size}, "
                f"actual is {len(actual_data)} please check.")
        return pos
79
+
80
+
81
def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    """Validate actual_data against config_spec with the first matching registered validator.

    :param config_spec: spec string from config.json, e.g. 'tensor' or 'tuple[2]:1'.
    :param actual_data: the runtime value to check against the spec.
    :param module_name: module name, used only in log/error messages.
    :param data_type: data direction label (e.g. input/output), for messages.
    :return: the focused column index for tuple specs, otherwise None; also
        None when validation fails (the failure is logged, not raised).
    """
    focused_col = None
    for config_validator_cls in config_validator_registry.values():
        config_validator = config_validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if pattern_match:
            try:
                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
            except ValueError as e:
                logger.warning(f"config spec validate failed: {str(e)}")
            return focused_col
    # Bug fix: the regex example must live in a raw string segment — the original
    # f-string used invalid escape sequences (\[ and \d), which emit
    # SyntaxWarning/DeprecationWarning on modern Python. The rendered message
    # is unchanged.
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   r"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', "
                   f"actual spec: {config_spec}.")
    return focused_col
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from mindspore import dtype as mstype, Tensor
19
+
20
+ from msprobe.mindspore.monitor.features import FUNC_MAP
21
+ from msprobe.core.common.const import MonitorConst
22
+ from msprobe.core.common.utils import is_int
23
+ from msprobe.core.common.log import logger
24
+ from msprobe.core.common.file_utils import check_file_or_directory_path
25
+
26
+
27
def get_single_metrics(op_list, tag, tensor, output=None):
    """Compute every statistic op in op_list over tensor, storing results in output[tag][op].

    Mutates (or creates) `output` in place; returns None.

    :param op_list: statistic op names; each is resolved through FUNC_MAP.
    :param tag: grouping key for the results within `output`.
    :param tensor: input handed to each op's function.
    :param output: optional result dict to fill; a new one is created when None.
    """
    if output is None:
        output = {}
    if tag not in output:
        output[tag] = {}
    for op in op_list:
        # NOTE(review): FUNC_MAP.get returns None for an unknown op, which would
        # raise TypeError on the next line — ops are presumably pre-validated
        # by validate_ops; confirm against callers.
        func = FUNC_MAP.get(op)
        statistic = func(tensor)
        if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
            # bfloat16 results are routed through a python float and rewrapped
            # as a Tensor so the float32 cast below is well defined.
            statistic = float(statistic)
            statistic = Tensor(statistic)
        # Stored uniformly as float32 so downstream writers see one dtype.
        output[tag][op] = statistic.astype(mstype.float32)
39
+
40
+
41
def get_metrics(op_list, tag2tensor, eps, output=None):
    """Compute every op over every (tag, tensor) pair and return the result dict.

    :param op_list: statistic op names.
    :param tag2tensor: mapping of tag -> tensor to process.
    :param eps: accepted for interface compatibility; not used in this function.
    :param output: optional dict filled in place; created when None.
    :return: the filled output dict.
    """
    output = {} if output is None else output
    for tag, tensor in tag2tensor.items():
        output.setdefault(tag, {})
        get_single_metrics(op_list, tag, tensor, output)
    return output
49
+
50
+
51
def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
    """Build a summary-writer tag path, inserting a 'rank<N>' segment when rank is given."""
    if rank is None:
        return f"{module_or_param_name}/{tag}"
    return f"{module_or_param_name}/rank{rank}/{tag}"
56
+
57
+
58
def step_accumulates_one(context, micro_batch_number):
    """Advance the context by one micro step.

    Increments context.micro_step; once it reaches micro_batch_number the
    micro counter resets to 0 and context.step advances by one.

    :param context: ModuleHookContext
    :param micro_batch_number: micro-batches per global step of the training model.
    :return: None
    """
    context.micro_step += 1
    if context.micro_step != micro_batch_number:
        return
    context.micro_step = 0
    context.step += 1
68
+
69
+
70
def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
    """Decide whether the current step should be skipped by the monitor.

    Skips when the step is before start_step, off the sampling interval, or
    the collection budget is exhausted.

    :param step: current training step, int
    :param start_step: first step to collect, int
    :param step_interval: sampling stride, int
    :param has_collect_times: collections done so far
    :param collect_times: maximum number of collections
    :return: True when the step should be skipped, bool
    """
    not_started = step < start_step
    off_interval = (step - start_step) % step_interval != 0
    budget_spent = has_collect_times >= collect_times
    return not_started or off_interval or budget_spent
79
+
80
+
81
def validate_ops(ops):
    """Return the subset of ops supported by MonitorConst.OP_LIST.

    Unsupported entries are dropped with a warning; when nothing survives,
    the first supported op is used as a default.

    :raises TypeError: when ops is not a list.
    """
    if not isinstance(ops, list):
        raise TypeError("ops should be a list")
    valid_ops = []
    for op in ops:
        if op in MonitorConst.OP_LIST:
            valid_ops.append(op)
        else:
            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
    if valid_ops:
        return valid_ops
    default_op = MonitorConst.OP_LIST[0]
    logger.info(f"There is no valid ops, default op {default_op} is used")
    return [default_op]
95
+
96
+
97
def validate_ranks(ranks):
    """Ensure module_ranks is a list whose elements are all ints; raise TypeError otherwise."""
    if not isinstance(ranks, list):
        raise TypeError("module_ranks should be a list")
    for rank in ranks:
        if isinstance(rank, int):
            continue
        raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
103
+
104
+
105
def validate_targets(targets):
    """Ensure targets is a dict of str module names to dict field specs; raise TypeError otherwise."""
    if not isinstance(targets, dict):
        raise TypeError('targets in config.json should be a dict')
    for module_name, field in targets.items():
        if not isinstance(module_name, str):
            raise TypeError('key of targets should be module_name[str] in config.json')
        if not isinstance(field, dict):
            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
113
+
114
+
115
def validate_print_struct(print_struct):
    """Ensure print_struct is a bool; raise TypeError otherwise."""
    if isinstance(print_struct, bool):
        return
    raise TypeError("print_struct should be a bool")
118
+
119
+
120
def validate_ur_distribution(ur_distribution):
    """Ensure ur_distribution is a bool; raise TypeError otherwise."""
    if isinstance(ur_distribution, bool):
        return
    raise TypeError('ur_distribution should be a bool')
123
+
124
+
125
def validate_xy_distribution(xy_distribution):
    """Ensure xy_distribution is a bool; raise TypeError otherwise."""
    if isinstance(xy_distribution, bool):
        return
    raise TypeError('xy_distribution should be a bool')
128
+
129
+
130
def validate_wg_distribution(wg_distribution):
    """Ensure wg_distribution is a bool; raise TypeError otherwise."""
    if isinstance(wg_distribution, bool):
        return
    raise TypeError('wg_distribution should be a bool')
133
+
134
+
135
def validate_mg_distribution(mg_distribution):
    """Ensure mg_distribution is a bool; raise TypeError otherwise."""
    if isinstance(mg_distribution, bool):
        return
    raise TypeError('mg_distribution should be a bool')
138
+
139
+
140
def validate_param_distribution(param_distribution):
    """Ensure param_distribution is a bool; raise TypeError otherwise."""
    if isinstance(param_distribution, bool):
        return
    raise TypeError('param_distribution should be a bool')
143
+
144
+
145
def validate_cc_distribution(cc_distribution):
    """Ensure cc_distribution is a dict containing only supported, correctly typed keys.

    Supported keys: enable (bool), cc_codeline (list), cc_pre_hook (bool),
    cc_log_only (bool). Raises TypeError on any violation.
    """
    if not isinstance(cc_distribution, dict):
        raise TypeError('cc_distribution should be a dictionary')
    expected_keys = {
        'enable': bool,
        'cc_codeline': list,
        'cc_pre_hook': bool,
        'cc_log_only': bool
    }
    for key, value in cc_distribution.items():
        expected = expected_keys.get(key)
        if expected is None:
            raise TypeError(f'{key} of cc_distribution is not supported.')
        if not isinstance(value, expected):
            raise TypeError(f'cc_distribution {key} should be a {expected.__name__}')
160
+
161
+
162
def validate_alert(alert):
    """Validate the 'alert' section: optional 'rules' list and optional 'dump' bool.

    Each rule may carry a 'rule_name' (must be in MonitorConst.RULE_NAME) and
    'args' with a 'threshold'. Raises TypeError on any violation; a non-list
    'rules' value is silently ignored, matching the original behavior.
    """
    if not isinstance(alert, dict):
        raise TypeError('alert should be a dictionary')
    rules = alert.get('rules')
    if rules and isinstance(rules, list):
        for rule in rules:
            rule_name = rule.get("rule_name")
            if rule_name and rule_name not in MonitorConst.RULE_NAME:
                raise TypeError(f"{rule_name} is not supported")
            args = rule.get("args")
            if args and isinstance(args, dict):
                threshold = args.get("threshold")
                # NOTE(review): a missing or int-valued threshold is rejected
                # here because only float passes isinstance — confirm intended.
                if not isinstance(threshold, float) or threshold < 0:
                    raise TypeError('threshold must be float and not less than 0')
    dump = alert.get('dump')
    if dump and not isinstance(dump, bool):
        raise TypeError('dump must be bool.')
179
+
180
+
181
def validate_step_count_per_record(step_count_per_record):
    """Validate step_count_per_record: an int in [1, 1e6].

    :raises TypeError: when not an int.
    :raises ValueError: when out of range.
    """
    if not is_int(step_count_per_record):
        raise TypeError('step_count_per_record must be int.')
    if step_count_per_record < 1:
        # Message grammar fixed ("must greater" -> "must be greater").
        raise ValueError("step_count_per_record must be greater than 0")
    if step_count_per_record > 1e6:
        raise ValueError("step_count_per_record must be smaller than 1e6")
188
+
189
+
190
def validate_start_step(start_step):
    """Validate start_step: an int in [0, 1e8].

    :raises TypeError: when not an int.
    :raises ValueError: when out of range.
    """
    if not is_int(start_step):
        raise TypeError('start_step must be int.')
    if start_step < 0:
        # Bug fix: 0 is accepted (check is < 0), so the message must say
        # "greater than or equal to 0", not "greater than 0".
        raise ValueError("start_step must be greater than or equal to 0")
    if start_step > 1e8:
        raise ValueError("start_step must be smaller than 1e8")
197
+
198
+
199
def validate_step_interval(step_interval):
    """Validate step_interval: an int in [1, 1e8].

    :raises TypeError: when not an int.
    :raises ValueError: when out of range.
    """
    if not is_int(step_interval):
        raise TypeError('step_interval must be int.')
    if step_interval < 1:
        # Bug fix: 1 is accepted (check is < 1), so the message must say
        # "greater than or equal to 1", not "greater than 1".
        raise ValueError("step_interval must be greater than or equal to 1")
    if step_interval > 1e8:
        raise ValueError("step_interval must be smaller than 1e8")
206
+
207
+
208
def validate_collect_times(collect_times):
    """Validate collect_times: an int of at least 1.

    :raises TypeError: when not an int.
    :raises ValueError: when smaller than 1.
    """
    if not is_int(collect_times):
        raise TypeError('collect_times must be int.')
    if collect_times < 1:
        # Bug fix: 1 is accepted (check is < 1), so the message must say
        # "greater than or equal to 1", not "greater than 1".
        raise ValueError("collect_times must be greater than or equal to 1")
213
+
214
+
215
def validate_dynamic_on(dynamic_on):
    """Ensure dynamic_on is a bool; raise TypeError otherwise."""
    if isinstance(dynamic_on, bool):
        return
    raise TypeError('dynamic_on should be a bool')
218
+
219
+
220
def validate_config(config):
    """Validate the whole monitor config dict in place.

    Each section is checked by its dedicated validator (raising TypeError /
    ValueError on bad input) and 'ops' is normalized to the supported subset.
    When no explicit targets are given, selection flags are derived: with
    xy_distribution on, all modules are monitored (all_xy); otherwise
    is_select is set so nothing is picked implicitly.
    """
    config['ops'] = validate_ops(config.get('ops', []))

    eps = config.get('eps', 1e-8)
    if not isinstance(eps, float):
        raise TypeError("eps should be a float")

    validate_ranks(config.get("module_ranks", []))

    # Kept in a local: used below to derive the selection flags.
    targets = config.get("targets", {})
    validate_targets(targets)

    validate_print_struct(config.get('print_struct', False))
    validate_ur_distribution(config.get('ur_distribution', False))

    # Kept in a local: used below to derive the selection flags.
    xy_distribution = config.get('xy_distribution', False)
    validate_xy_distribution(xy_distribution)

    validate_wg_distribution(config.get('wg_distribution', False))
    validate_mg_distribution(config.get('mg_distribution', False))
    validate_param_distribution(config.get('param_distribution', False))
    validate_cc_distribution(config.get('cc_distribution', {}))
    validate_alert(config.get('alert', {}))
    validate_step_count_per_record(config.get('step_count_per_record', 1))
    validate_start_step(config.get('start_step', 0))
    validate_step_interval(config.get('step_interval', 1))
    validate_collect_times(config.get('collect_times', int(1e8)))
    validate_dynamic_on(config.get('dynamic_on', False))

    if not targets:
        if xy_distribution:
            config["all_xy"] = True
            config["targets"] = {"": {}}
            config["is_select"] = False
        else:
            config["is_select"] = True
279
+
280
+
281
def time_str2time_digit(time_str):
    """Parse a '%b%d_%H-%M-%S' timestamp (e.g. 'Dec03_21-34-40') into a datetime.

    :param time_str: timestamp string taken from a monitor output dir name.
    :return: the parsed datetime (year defaults to 1900).
    :raises RuntimeError: when time_str does not match the expected format.
    """
    time_format = '%b%d_%H-%M-%S'
    try:
        time_digit = datetime.strptime(time_str, time_format)
    except Exception as e:
        # Bug fix: the original message used a backslash line continuation
        # inside the f-string, embedding a long run of indentation spaces in
        # the user-visible error text. Adjacent literals keep it clean.
        raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix "
                           f"of existing output dirpath, like 'Dec03_21-34-40'.") from e
    return time_digit
289
+
290
+
291
def get_target_output_dir(monitor_path, time_start, time_end):
    """Map rank -> monitor output dir for dirs whose timestamp lies in [time_start, time_end].

    :param monitor_path: directory containing the monitor output dirs.
    :param time_start: inclusive lower bound as a '%b%d_%H-%M-%S' string, or None for no bound.
    :param time_end: inclusive upper bound, same format, or None for no bound.
    :return: dict of rank (regex group 2 of OUTPUT_DIR_PATTERN) to full dir path.
        NOTE(review): if several matching dirs share a rank, the last one
        listed wins — confirm that is intended.
    :raises ValueError: when both bounds are given and time_start > time_end.
    """
    check_file_or_directory_path(monitor_path, isdir=True)
    time_start = time_str2time_digit(time_start) if time_start is not None else time_start
    time_end = time_str2time_digit(time_end) if time_end is not None else time_end
    if time_start and time_end and time_start > time_end:
        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")

    result = {}
    for dirname in os.listdir(monitor_path):
        match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
        if not match:
            continue
        dir_time = time_str2time_digit(match.group(1))
        within_start = time_start is None or dir_time >= time_start
        within_end = time_end is None or dir_time <= time_end
        if within_start and within_end:
            result[match.group(2)] = os.path.join(monitor_path, dirname)
    return result
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -106,12 +106,18 @@ class GradProbeConfig(BaseConfig):
106
106
  check_numeral_list_ascend(self.bounds)
107
107
 
108
108
 
109
+ class StructureConfig(BaseConfig):
110
+ def __init__(self, json_config):
111
+ super().__init__(json_config)
112
+
113
+
109
114
  TaskDict = {
110
115
  Const.TENSOR: TensorConfig,
111
116
  Const.STATISTICS: StatisticsConfig,
112
117
  Const.OVERFLOW_CHECK: OverflowCheckConfig,
113
118
  Const.FREE_BENCHMARK: FreeBenchmarkConfig,
114
- Const.GRAD_PROBE: GradProbeConfig
119
+ Const.GRAD_PROBE: GradProbeConfig,
120
+ Const.STRUCTURE: StructureConfig
115
121
  }
116
122
 
117
123
 
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from msprobe.core.common.log import logger
16
17
  from msprobe.mindspore.common.const import Const
17
18
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
19
  from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
@@ -44,6 +45,7 @@ class OverflowCheckToolFactory:
44
45
  raise Exception("Valid level is needed.")
45
46
  tool = tool.get(config.execution_mode)
46
47
  if not tool:
47
- raise Exception(f"Overflow check is not supported in {config.execution_mode} mode "
48
- f"when level is {config.level}.")
48
+ logger.error(f"Overflow check is not supported in {config.execution_mode} mode "
49
+ f"when level is {config.level}.")
50
+ raise ValueError
49
51
  return tool(config)