mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,94 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import abc
+from mindspore import Tensor
+
+from msprobe.core.common.log import logger
+
+
+# Registry of all validator implementation classes
+config_validator_registry = {}
+
+
+def register_config_validator(cls):
+    """Decorator used to register ConfigValidator implementation classes"""
+    config_validator_registry[cls.__name__] = cls
+    return cls
+
+
+class ConfigValidator(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def check_pattern_match(self, config_spec: str):
+        pass
+
+    @abc.abstractmethod
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        pass
+
+
+@register_config_validator
+class TensorValidator(ConfigValidator):
+    def check_pattern_match(self, config_spec: str):
+        pattern = re.compile(r"tensor")
+        return pattern.match(config_spec)
+
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        if not isinstance(actual_data, Tensor):
+            raise ValueError(
+                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
+
+
+@register_config_validator
+class TupleValidator(ConfigValidator):
+    def check_pattern_match(self, config_spec: str):
+        pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
+        return pattern.match(config_spec)
+
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        length, index = pattern_match.groups()
+        if index is None:
+            index = 0
+        length, index = int(length), int(index)
+
+        if not (0 <= index < length):
+            raise ValueError(
+                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'."
+                f"y must be greater than or equal to 0 and less than x.")
+        if not isinstance(actual_data, tuple):
+            raise ValueError(
+                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
+        if len(actual_data) != length:
+            raise ValueError(
+                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
+                f"actual is {len(actual_data)} please check.")
+        return index
+
+
+def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
+    focused_col = None
+    for _, validator_cls in config_validator_registry.items():
+        config_validator = validator_cls()
+        pattern_match = config_validator.check_pattern_match(config_spec)
+        if pattern_match:
+            try:
+                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
+            except ValueError as e:
+                logger.warning(f"config spec validate failed: {str(e)}")
+            return focused_col
+    logger.warning(f"config spec in {module_name} {data_type} not supported, "
+                   f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
+    return focused_col
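This new validator resolves specs such as "tensor" or "tuple[x]:y" from config.json against the data actually seen at a hook point and returns the tuple index to monitor (or None). A minimal usage sketch, assuming the file lands at msprobe/mindspore/monitor/module_spec_verifier.py as the file list above suggests; the module and field names below are illustrative:

```python
import mindspore as ms
from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec  # assumed path

data = (ms.Tensor([1.0, 2.0]), ms.Tensor([3.0]))

# "tuple[2]:1": the hooked value must be a 2-element tuple; element 1 is the one to monitor.
print(validate_config_spec("tuple[2]:1", data, "layer.mlp", "input"))   # -> 1

# "tensor": the hooked value must be a single Tensor; no column index is returned.
print(validate_config_spec("tensor", data[0], "layer.mlp", "input"))    # -> None

# Mismatches do not raise here: validate_config_spec logs a warning and returns None.
print(validate_config_spec("tensor", data, "layer.mlp", "input"))       # -> None
```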
@@ -0,0 +1,267 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mindspore import dtype as mstype, Tensor
+
+from msprobe.mindspore.monitor.features import FUNC_MAP
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.utils import is_int
+from msprobe.core.common.log import logger
+
+
+def get_single_metrics(op_list, tag, tensor, output=None):
+    if output is None:
+        output = {}
+    if tag not in output:
+        output[tag] = {}
+    for op in op_list:
+        func = FUNC_MAP.get(op)
+        statistic = func(tensor)
+        if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
+            statistic = float(statistic)
+            statistic = Tensor(statistic)
+        output[tag][op] = statistic.astype(mstype.float32)
+
+
+def get_metrics(op_list, tag2tensor, eps, output=None):
+    if output is None:
+        output = {}
+    for tag, tensor in tag2tensor.items():
+        if tag not in output:
+            output[tag] = {}
+        get_single_metrics(op_list, tag, tensor, output)
+    return output
+
+
+def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
+    if rank is None:
+        return f"{module_or_param_name}/{tag}"
+    else:
+        return f"{module_or_param_name}/rank{rank}/{tag}"
+
+
+def step_accumulates_one(context, micro_batch_number):
+    """
+    :param context: ModuleHookContext
+    :param micro_batch_number: mbs of training model.
+    :return:
+    """
+    context.micro_step += 1
+    if context.micro_step == micro_batch_number:
+        context.micro_step = 0
+        context.step += 1
+
+
+def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
+    """
+    If current step less than start_step or not reach step_interval, skip current step.
+    :param step: current training step, int
+    :param start_step: int
+    :param step_interval: int
+    :return: whether skip or not, bool
+    """
+    return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
+
+
+def validate_ops(ops):
+    if not isinstance(ops, list):
+        raise TypeError("ops should be a list")
+    valid_ops = []
+    for op in ops:
+        if op not in MonitorConst.OP_LIST:
+            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
+            continue
+        valid_ops.append(op)
+    if not valid_ops:
+        default_op = MonitorConst.OP_LIST[0]
+        valid_ops.append(default_op)
+        logger.info(f"There is no valid ops, default op {default_op} is used")
+    return valid_ops
+
+
+def validate_ranks(ranks):
+    if not isinstance(ranks, list):
+        raise TypeError("module_ranks should be a list")
+    for rank in ranks:
+        if not isinstance(rank, str):
+            raise TypeError(f"element in module_ranks should be a str, get {type(rank)}")
+
+
+def validate_targets(targets):
+    if not isinstance(targets, dict):
+        raise TypeError('targets in config.json should be a dict')
+    for module_name, field in targets.items():
+        if not isinstance(module_name, str):
+            raise TypeError('key of targets should be module_name[str] in config.json')
+        if not isinstance(field, dict):
+            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
+
+
+def validate_print_struct(print_struct):
+    if not isinstance(print_struct, bool):
+        raise TypeError("print_struct should be a bool")
+
+
+def validate_ur_distribution(ur_distribution):
+    if not isinstance(ur_distribution, bool):
+        raise TypeError('ur_distribution should be a bool')
+
+
+def validate_xy_distribution(xy_distribution):
+    if not isinstance(xy_distribution, bool):
+        raise TypeError('xy_distribution should be a bool')
+
+
+def validate_wg_distribution(wg_distribution):
+    if not isinstance(wg_distribution, bool):
+        raise TypeError('wg_distribution should be a bool')
+
+
+def validate_mg_distribution(mg_distribution):
+    if not isinstance(mg_distribution, bool):
+        raise TypeError('mg_distribution should be a bool')
+
+
+def validate_param_distribution(param_distribution):
+    if not isinstance(param_distribution, bool):
+        raise TypeError('param_distribution should be a bool')
+
+
+def validate_cc_distribution(cc_distribution):
+    if not isinstance(cc_distribution, dict):
+        raise TypeError('cc_distribution should be a dictionary')
+    expected_keys = {
+        'enable': bool,
+        'cc_codeline': list,
+        'cc_pre_hook': bool,
+        'cc_log_only': bool
+    }
+    for key, value in cc_distribution.items():
+        if key in expected_keys:
+            if not isinstance(value, expected_keys[key]):
+                raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}')
+        else:
+            raise TypeError(f'{key} of cc_distribution is not supported.')
+
+
+def validate_alert(alert):
+    if not isinstance(alert, dict):
+        raise TypeError('alert should be a dictionary')
+    rules = alert.get('rules')
+    if rules and isinstance(rules, list):
+        for rule in rules:
+            rule_name = rule.get("rule_name")
+            if rule_name and rule_name not in MonitorConst.RULE_NAME:
+                raise TypeError(f"{rule_name} is not supported")
+            args = rule.get("args")
+            if args and isinstance(args, dict):
+                threshold = args.get("threshold")
+                if not isinstance(threshold, float) or threshold < 0:
+                    raise TypeError('threshold must be float and not less than 0')
+    dump = alert.get('dump')
+    if dump and not isinstance(dump, bool):
+        raise TypeError('dump must be bool.')
+
+
+def validate_step_count_per_record(step_count_per_record):
+    if not is_int(step_count_per_record):
+        raise TypeError('step_count_per_record must be int.')
+    if step_count_per_record < 1:
+        raise ValueError("step_count_per_record must greater than 0")
+    if step_count_per_record > 1e6:
+        raise ValueError("step_count_per_record must smaller than 1e6")
+
+
+def validate_start_step(start_step):
+    if not is_int(start_step):
+        raise TypeError('start_step must be int.')
+    if start_step < 0:
+        raise ValueError("start_step must greater than 0")
+    if start_step > 1e8:
+        raise ValueError("start_step must smaller than 1e8")
+
+
+def validate_step_interval(step_interval):
+    if not is_int(step_interval):
+        raise TypeError('step_interval must be int.')
+    if step_interval < 1:
+        raise ValueError("step_interval must greater than 1")
+    if step_interval > 1e8:
+        raise ValueError("step_interval must smaller than 1e8")
+
+
+def validate_collect_times(collect_times):
+    if not is_int(collect_times):
+        raise TypeError('collect_times must be int.')
+    if collect_times < 1:
+        raise ValueError("collect_times must greater than 1")
+
+
+def validate_config(config):
+    config['ops'] = validate_ops(config.get('ops', []))
+
+    eps = config.get('eps', 1e-8)
+    if not isinstance(eps, float):
+        raise TypeError("eps should be a float")
+
+    ranks = config.get("module_ranks", [])
+    validate_ranks(ranks)
+
+    targets = config.get("targets", {})
+    validate_targets(targets)
+
+    print_struct = config.get('print_struct', False)
+    validate_print_struct(print_struct)
+
+    ur_distribution = config.get('ur_distribution', False)
+    validate_ur_distribution(ur_distribution)
+
+    xy_distribution = config.get('xy_distribution', False)
+    validate_xy_distribution(xy_distribution)
+
+    wg_distribution = config.get('wg_distribution', False)
+    validate_wg_distribution(wg_distribution)
+
+    mg_distribution = config.get('mg_distribution', False)
+    validate_mg_distribution(mg_distribution)
+
+    param_distribution = config.get('param_distribution', False)
+    validate_param_distribution(param_distribution)
+
+    cc_distribution = config.get('cc_distribution', {})
+    validate_cc_distribution(cc_distribution)
+
+    alert = config.get('alert', {})
+    validate_alert(alert)
+
+    step_count_per_record = config.get('step_count_per_record', 1)
+    validate_step_count_per_record(step_count_per_record)
+
+    start_step = config.get('start_step', 0)
+    validate_start_step(start_step)
+
+    step_interval = config.get('step_interval', 1)
+    validate_step_interval(step_interval)
+
+    collect_times = config.get('collect_times', 1e8)
+    validate_collect_times(collect_times)
+
+    if not targets:
+        if xy_distribution:
+            config["all_xy"] = True
+        config["targets"] = {"": {}}
+        config["is_select"] = False
+    else:
+        config["is_select"] = True
msprobe/mindspore/ms_config.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -45,7 +45,11 @@ class StatisticsConfig(BaseConfig):
         self._check_config()
 
     def _check_config(self):
-
+        single_opt = ["statistics", "md5"]
+        muti_opt = ["md5", "max", "min", "mean", "l2norm"]
+        if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt:
+            raise Exception("summary_mode is invalid")
+        if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
             raise Exception("summary_mode is invalid")
 
 
@@ -102,12 +106,18 @@ class GradProbeConfig(BaseConfig):
         check_numeral_list_ascend(self.bounds)
 
 
+class StructureConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+
+
 TaskDict = {
     Const.TENSOR: TensorConfig,
     Const.STATISTICS: StatisticsConfig,
     Const.OVERFLOW_CHECK: OverflowCheckConfig,
     Const.FREE_BENCHMARK: FreeBenchmarkConfig,
-    Const.GRAD_PROBE: GradProbeConfig
+    Const.GRAD_PROBE: GradProbeConfig,
+    Const.STRUCTURE: StructureConfig
 }
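With this change, summary_mode in the statistics task accepts either a single string ("statistics" or "md5") or a list drawn from ["md5", "max", "min", "mean", "l2norm"]. A standalone sketch of the accepted values, mirroring the check above (illustrative helper, not an msprobe API call):

```python
single_opt = ["statistics", "md5"]
muti_opt = ["md5", "max", "min", "mean", "l2norm"]

def summary_mode_is_valid(summary_mode):
    # Mirrors StatisticsConfig._check_config: a string must be one of single_opt,
    # a list must contain only entries from muti_opt.
    if isinstance(summary_mode, str):
        return summary_mode in single_opt
    if isinstance(summary_mode, list):
        return all(opt in muti_opt for opt in summary_mode)
    return False

print(summary_mode_is_valid("md5"))                     # True
print(summary_mode_is_valid(["max", "min", "l2norm"]))  # True
print(summary_mode_is_valid(["max", "variance"]))       # False: "variance" is not supported
```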
@@ -46,6 +46,13 @@ class KernelGraphOverflowCheck:
         self.dump_json["common_dump_settings"]["op_debug_mode"] = 2
 
     def handle(self):
+        try:
+            from msprobe.lib import _msprobe_c
+            return
+        except ImportError:
+            # If _msprobe_c is not available, fall back to the original MindSpore flow
+            logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
+
         if os.getenv("GRAPH_OP_RUN") == "1":
             raise Exception("Must run in graph mode, not kbk mode")
         json_path = self.dump_json["common_dump_settings"]["path"]
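The handle() change tries the compiled _msprobe_c backend first and only continues with the pure-MindSpore flow when the import fails. A generic, standalone sketch of this optional-backend pattern (function and logger names below are illustrative, not msprobe APIs):

```python
import logging

logger = logging.getLogger("msprobe")

def run_kernel_graph_check():
    try:
        # Prefer the compiled backend when the wheel ships it.
        from msprobe.lib import _msprobe_c  # noqa: F401
    except ImportError:
        logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
    else:
        # The compiled backend takes over; skip the Python/MindSpore flow.
        return
    # ... legacy MindSpore-based handling continues here ...
```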