mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/core/data_dump/api_registry.py
ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any, Optional, Callable, Union, List, Tuple
+
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import load_yaml
+
+
+def _get_attr(module, attr_name):
+    if Const.SEP in attr_name:
+        sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1)
+        sub_module = getattr(module, sub_module_name, None)
+        attr = getattr(sub_module, sub_attr, None)
+    else:
+        attr = getattr(module, attr_name, None)
+    return attr
+
+
+class ApiWrapper:
+    def __init__(
+            self, api_types: Dict[str, Dict[str, Any]],
+            api_list_paths: Union[str, List[str], Tuple[str]]
+    ):
+        self.api_types = api_types
+        if not isinstance(api_list_paths, (list, tuple)):
+            api_list_paths = [api_list_paths] * len(self.api_types)
+        elif len(api_list_paths) != len(self.api_types):
+            raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
+                               "when api_list_paths is a list or tuple.")
+        self.api_list_paths = api_list_paths
+        self.api_names = self._get_api_names()
+        self.wrapped_api_functions = dict()
+
+    def wrap_api(
+            self, api_templates, hook_build_func: Optional[Callable]
+    ):
+        api_types_num = sum([len(v) for v in self.api_types.values()])
+        if not isinstance(api_templates, (list, tuple)):
+            api_templates = [api_templates] * api_types_num
+        elif len(api_templates) != api_types_num:
+            raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
+                               "when api_templates is a list or tuple.")
+
+        self.wrapped_api_functions.clear()
+        index = 0
+        for framework, api_types in self.api_types.items():
+            wrapped_functions_in_framework = dict()
+            for api_type, api_modules in api_types.items():
+                wrapped_functions = dict()
+                name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
+                api_template = api_templates[index]
+                index += 1
+                for api_name in self.api_names.get(framework, {}).get(api_type, []):
+                    ori_api = _get_attr(api_modules[0], api_name)
+                    if callable(ori_api):
+                        def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
+                            def api_function(*args, **kwargs):
+                                return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
+                            api_function.__name__ = api_name
+                            return api_function
+                        wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
+                                                                    hook_build_func, api_template)
+                wrapped_functions_in_framework[api_type] = wrapped_functions
+            self.wrapped_api_functions[framework] = wrapped_functions_in_framework
+        return self.wrapped_api_functions
+
+    def _get_api_names(self):
+        api_names = dict()
+
+        for index, framework in enumerate(self.api_types.keys()):
+            api_list = load_yaml(self.api_list_paths[index])
+            valid_names = dict()
+            for api_type, api_modules in self.api_types.get(framework, {}).items():
+                api_from_file = api_list.get(Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type), [])
+                names = set()
+                for api_name in api_from_file:
+                    target_attr = api_name
+                    target_module = api_modules[0]
+                    if Const.SEP in api_name:
+                        sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
+                        target_module = getattr(api_modules[0], sub_module_name, None)
+                    if target_module and target_attr in dir(target_module):
+                        names.add(api_name)
+                valid_names[api_type] = names
+            api_names[framework] = valid_names
+
+        return api_names
+
+
+class ApiRegistry:
+    """
+    Base class for api registry.
+    """
+
+    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
+        self.ori_api_attr = dict()
+        self.wrapped_api_attr = dict()
+        self.inner_used_ori_attr = dict()
+        self.inner_used_wrapped_attr = dict()
+        self.api_types = api_types
+        self.inner_used_api = inner_used_api
+        self.supported_api_list_path = supported_api_list_path
+        self.api_templates = api_templates
+
+    @staticmethod
+    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
+        for api in api_list:
+            api_ori_attr[api] = _get_attr(ori_api_group, api)
+
+    @staticmethod
+    def set_api_attr(api_group, attr_dict):
+        for api, api_attr in attr_dict.items():
+            if Const.SEP in api:
+                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
+                sub_module = getattr(api_group, sub_module_name, None)
+                if sub_module is not None:
+                    setattr(sub_module, sub_op, api_attr)
+            else:
+                setattr(api_group, api, api_attr)
+
+    def register_all_api(self):
+        for framework, api_types in self.api_types.items():
+            for api_type, api_modules in api_types.items():
+                api_type_with_framework = framework + Const.SEP + api_type
+                for module in api_modules[1]:
+                    self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {}))
+
+    def register_inner_used_api(self):
+        for api_type in self.inner_used_api.keys():
+            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))
+
+    def restore_all_api(self):
+        for framework, api_types in self.api_types.items():
+            for api_type, api_modules in api_types.items():
+                api_type_with_framework = framework + Const.SEP + api_type
+                for module in api_modules[1]:
+                    self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {}))
+
+    def restore_inner_used_api(self):
+        for api_type in self.inner_used_api.keys():
+            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
+
+    def initialize_hook(self, hook_build_func):
+        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
+        wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
+
+        for framework, api_types in self.api_types.items():
+            for api_type, api_modules in api_types.items():
+                ori_attr = dict()
+                self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
+                api_type_with_framework = framework + Const.SEP + api_type
+                self.ori_api_attr[api_type_with_framework] = ori_attr
+                self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)
+
+        for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
+            ori_attr = dict()
+            wrapped_attr = dict()
+            for api_name in inner_used_api_list[1:]:
+                if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name):
+                    ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name)
+                    wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name)
+            self.inner_used_ori_attr[inner_used_api_type] = ori_attr
+            self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
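This new module centralizes the hook machinery that previously lived in the per-framework `api_registry.py`/`wrap_*.py` files (deleted at the bottom of the file list above). To make the moving parts concrete, here is a minimal, self-contained sketch of the same store/wrap/restore pattern; the `MiniRegistry` name, the `math` target module, and the `make_hook` callback are illustrative inventions, not msprobe API:

```python
import math

# Miniature version of the ApiRegistry idea: remember the original attribute,
# install a wrapped version, and restore the original on demand.
class MiniRegistry:
    def __init__(self, module, api_names):
        self.module = module
        self.api_names = api_names
        self.ori_attr = {}      # original callables, keyed by name
        self.wrapped_attr = {}  # hooked callables, keyed by name

    def initialize(self, make_hook):
        for name in self.api_names:
            ori = getattr(self.module, name, None)
            if callable(ori):
                self.ori_attr[name] = ori
                self.wrapped_attr[name] = make_hook(name, ori)

    def register_all_api(self):
        for name, fn in self.wrapped_attr.items():
            setattr(self.module, name, fn)

    def restore_all_api(self):
        for name, fn in self.ori_attr.items():
            setattr(self.module, name, fn)


def make_hook(name, ori):
    def hooked(*args, **kwargs):
        print(f"dump: {name} called with {args}")
        return ori(*args, **kwargs)
    hooked.__name__ = name
    return hooked


registry = MiniRegistry(math, ["sqrt", "pow"])
registry.initialize(make_hook)
registry.register_all_api()
math.sqrt(4.0)           # traced through the hook
registry.restore_all_api()
math.sqrt(4.0)           # back to the original
```

The real `ApiRegistry` does the same thing per framework and per API type (hence the nested `framework`/`api_type` dictionaries), and keeps a separate `inner_used` set so the tool can temporarily unhook only the APIs it calls internally.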
msprobe/core/data_dump/data_processor/base.py
CHANGED
@@ -252,8 +252,8 @@ class BaseDataProcessor:
 
     @classmethod
     def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
-        if depth > Const.MAX_DEPTH:
-            logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
+        if depth > Const.DUMP_MAX_DEPTH:
+            logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
         if isinstance(args, cls.get_special_types()):
             arg_transform = transform(args, cls._recursive_key_stack)
msprobe/core/data_dump/data_processor/mindspore_processor.py
CHANGED
@@ -26,7 +26,7 @@ from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, Tenso
 from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
 from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
 
 has_adump = True
 try:
@@ -44,6 +44,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
             "dtype": self.analyze_dtype_in_kwargs
         }
         self._async_dump_cache = {}
+        self.api_register = get_api_register()
 
     @staticmethod
     def get_md5_for_tensor(x):
@@ -74,46 +75,29 @@ class MindsporeDataProcessor(BaseDataProcessor):
         else:
             if not ops.is_floating_point(data) or data.dtype == ms.float64:
                 data = data.to(ms.float32)
-            api_register.norm_inner_op_set_ori_func()
-            get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
-            get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
-            get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
-            if hasattr(mint, "norm"):
-                get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
-            else:
-                get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
-            tensor_stat.max = get_max_value(data).item()
-            tensor_stat.min = get_min_value(data).item()
-            tensor_stat.mean = get_mean_value(data).item()
+            get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
+            tensor_stat.max = mint.max(data).item()
+            tensor_stat.min = mint.min(data).item()
+            tensor_stat.mean = mint.mean(data).item()
             tensor_stat.norm = get_norm_value(data).item()
-            api_register.norm_inner_op_set_hook_func()
         return tensor_stat
 
     @staticmethod
     def get_stat_info_async(data):
         tensor_stat = TensorStatInfo()
-        stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
         if data.dtype == ms.complex64 or data.dtype == ms.complex128:
             logger.warning("Async dump do not support complex data!")
             return tensor_stat
         elif data.dtype == ms.bool_:
-            tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
+            tensor_stat.stack_tensor_stat = (["Max", "Min"], ops.stack([data.any(), data.all()]))
         elif not data.shape:
-            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack([data, data, data, data]))
         else:
             if not ops.is_floating_point(data) or data.dtype == ms.float64:
                 data = data.to(ms.float32)
-            api_register.norm_inner_op_set_ori_func()
-            get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
-            get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
-            get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
-            if hasattr(mint, "norm"):
-                get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
-            else:
-                get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
-            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
-                [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
-            api_register.norm_inner_op_set_hook_func()
+            get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack(
+                [mint.max(data), mint.min(data), mint.mean(data), get_norm_value(data)]))
         return tensor_stat
 
     @staticmethod
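The async path packs all four statistics into a single device tensor instead of calling `.item()` per statistic, so the dump hook does not force a device-to-host sync on every call. A standalone sketch of that deferral idea (requires MindSpore; `async_stats` is an illustrative name):

```python
import mindspore as ms
from mindspore import ops

def async_stats(data: ms.Tensor) -> ms.Tensor:
    # One device-side stack now; the host sync (.item()) happens later,
    # when the JSON writer drains its cache.
    return ops.stack([data.max(), data.min(), data.mean(), ops.norm(data)])

t = ms.Tensor([1.0, 2.0, 3.0])
packed = async_stats(t)            # still a device tensor of shape (4,)
print([v.item() for v in packed])  # single, deferred sync point
```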
@@ -125,14 +109,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
         return super().get_special_types() + cls.mindspore_special_type
 
     def get_stat_info(self, data):
+        self.api_register.restore_inner_used_api()
         tensor_stat = TensorStatInfo()
         if data.numel() == 0:
-            return tensor_stat
+            stat_info = tensor_stat
         else:
             if self.config.async_dump:
-                return MindsporeDataProcessor.get_stat_info_async(data)
+                stat_info = MindsporeDataProcessor.get_stat_info_async(data)
             else:
-                return MindsporeDataProcessor.get_stat_info_sync(data)
+                stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
+        self.api_register.register_inner_used_api()
+        return stat_info
 
     def analyze_single_element(self, element, suffix_stack):
         if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
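`get_stat_info` now brackets the statistics computation between `restore_inner_used_api()` and `register_inner_used_api()`, so the processor's own `mint.max`/`mint.min`/`mint.mean` calls are not re-captured by the very hooks it serves. The shape of that bracket in isolation, as a sketch (the `try/finally` hardening is my addition, not part of the diff):

```python
def compute_with_hooks_disarmed(registry, compute, data):
    registry.restore_inner_used_api()       # swap the original ops back in
    try:
        return compute(data)                # stats run against unhooked ops
    finally:
        registry.register_inner_used_api()  # re-arm the dump hooks
```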
@@ -191,7 +178,7 @@ class TensorDataProcessor(MindsporeDataProcessor):
         else:
             save_tensor_as_npy(tensor, file_path)
         return single_arg
-
+
     def _analyze_numpy(self, ndarray, suffix):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         save_npy(ndarray, file_path)
@@ -244,7 +231,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.maybe_save_overflow_data()
         return api_info_struct if self.has_overflow else None
-
+
     def analyze_params(self, name, param_name, grad):
         self.has_overflow = False
         api_info_struct = super().analyze_params(name, param_name, grad)
msprobe/core/data_dump/data_processor/pytorch_processor.py
CHANGED
@@ -24,14 +24,15 @@ from torch import distributed as dist
 from torch.distributed.distributed_c10d import _get_default_group
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import path_len_exceeds_limit
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import convert_tuple
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
     ModuleForwardInputsOutputs, TensorStatInfo
-from msprobe.pytorch.common.utils import save_pt,
+from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor
 from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
-from msprobe.core.common.utils import recursion_depth_decorator
 
 is_gpu = False
 try:
@@ -78,14 +79,16 @@ class PytorchDataProcessor(BaseDataProcessor):
     def analyze_device_in_kwargs(element):
         single_arg = {}
         single_arg.update({'type': "torch.device"})
-        if not isinstance(element, str):
+        if isinstance(element, (int, str)):
+            single_arg.update({"value": element})
+        elif isinstance(element, torch.device):
             if hasattr(element, "index"):
                 device_value = element.type + ":" + str(element.index)
             else:
                 device_value = element.type
             single_arg.update({"value": device_value})
         else:
-            single_arg.update({"value": element})
+            logger.debug(f"Device type {type(element)} is not supported.")
         return single_arg
 
     @staticmethod
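A runnable paraphrase of the widened branch logic (`describe_device` is a made-up name, and `element.index is None` stands in for the `hasattr` check):

```python
import torch

def describe_device(element):
    # torch.device kwargs can arrive as an int ordinal, a string, or a
    # torch.device object; anything else is left without a recorded value.
    single_arg = {"type": "torch.device"}
    if isinstance(element, (int, str)):
        single_arg["value"] = element
    elif isinstance(element, torch.device):
        value = element.type if element.index is None else f"{element.type}:{element.index}"
        single_arg["value"] = value
    return single_arg

print(describe_device("cuda:0"))             # {'type': 'torch.device', 'value': 'cuda:0'}
print(describe_device(torch.device("cpu")))  # {'type': 'torch.device', 'value': 'cpu'}
```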
@@ -143,7 +146,7 @@ class PytorchDataProcessor(BaseDataProcessor):
         if data.is_meta:
             return tensor_stat
         data_clone = data.detach()
-        if data_clone.numel() == 0:
+        if not data_clone.numel() or not data_clone.data_ptr():
             return tensor_stat
         else:
             if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
@@ -214,6 +217,18 @@ class PytorchDataProcessor(BaseDataProcessor):
             logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
         return {"type": "torch.distributed.ReduceOp", "value": op_type}
 
+    @staticmethod
+    def _cast_to_float_if_fp8(tensor):
+        dtype = str(tensor.dtype)
+        if is_float8_tensor(tensor):
+            dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
+            logger.debug(
+                f"The {dtype} tensor analyzing/saving is unsupported in dump function. "
+                f"Casting to float for processing."
+            )
+            tensor = tensor.float()
+        return tensor, dtype
+
     @classmethod
     def get_special_types(cls):
         return super().get_special_types() + cls.pytorch_special_type
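The new `_cast_to_float_if_fp8` helper exists because reduction kernels such as max/min/mean are generally not implemented for fp8 dtypes, so statistics and saving run on a float32 copy while the original dtype string is preserved for the dump record. A standalone sketch (assumes a torch build that defines the fp8 dtypes; the hifloat8 branch is omitted):

```python
import torch

_FP8_DTYPES = {d for d in (getattr(torch, "float8_e4m3fn", None),
                           getattr(torch, "float8_e5m2", None)) if d is not None}

def cast_to_float_if_fp8(tensor: torch.Tensor):
    dtype = str(tensor.dtype)      # keep the original dtype for the record
    if tensor.dtype in _FP8_DTYPES:
        tensor = tensor.float()    # stats/saving run on the float32 copy
    return tensor, dtype
```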
@@ -228,7 +243,7 @@ class PytorchDataProcessor(BaseDataProcessor):
         if isinstance(element, dist.ProcessGroup):
             return self._analyze_process_group(element)
         if isinstance(element, dist.P2POp):
-            return self._analyze_p2pop(element)
+            return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         if isinstance(element, dist.ReduceOp):
             return self._analyze_reduce_op(element)
         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
@@ -247,10 +262,10 @@ class PytorchDataProcessor(BaseDataProcessor):
         module_input_output.update_output_with_args_and_kwargs()
         return super().analyze_forward_output(name, module, module_input_output)
 
-    def _analyze_p2pop(self, arg):
+    def _analyze_p2pop(self, arg, suffix):
         p2pop_info = {"class_type": "torch.distributed.P2POp"}
         try:
-            tensor_info = self._analyze_tensor(arg.tensor,
+            tensor_info = self._analyze_tensor(arg.tensor, suffix)
             p2pop_info.update({"tensor": tensor_info})
             p2pop_info.update({"op": arg.op.__name__})
             p2pop_info.update({"peer": arg.peer})
@@ -263,10 +278,11 @@ class PytorchDataProcessor(BaseDataProcessor):
         return p2pop_info
 
     def _analyze_tensor(self, tensor, suffix):
+        tensor, dtype = self._cast_to_float_if_fp8(tensor)
         tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
         tensor_json = {}
         tensor_json.update({'type': 'torch.Tensor'})
-        tensor_json.update({'dtype': str(tensor.dtype)})
+        tensor_json.update({'dtype': dtype})
         tensor_json.update({"shape": tensor.shape})
         if tensor_stat.stack_tensor_stat is None:
             tensor_json.update({"Max": tensor_stat.max})
@@ -305,13 +321,14 @@ class TensorDataProcessor(PytorchDataProcessor):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         single_arg = super()._analyze_tensor(tensor, suffix)
         single_arg.update({"data_name": dump_data_name})
+        tensor, _ = self._cast_to_float_if_fp8(tensor)
         if self.config.async_dump:
             self._async_dump_cache[file_path] = tensor.clone().detach()
         else:
             saved_tensor = tensor.clone().contiguous().detach()
             save_pt(saved_tensor, file_path)
         return single_arg
-
+
     def _analyze_numpy(self, ndarray, suffix):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         save_pt(torch.tensor(ndarray), file_path)
@@ -383,7 +400,8 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
         self._analyze_maybe_overflow_flag()
         if self.has_overflow:
             for file_path, tensor in self.cached_tensors_and_file_paths.items():
-                save_pt(tensor, file_path)
+                tensor, _ = self._cast_to_float_if_fp8(tensor)
+                save_pt(tensor.clone().contiguous().detach(), file_path)
             self.real_overflow_nums += 1
             if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
                 logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -508,11 +526,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
             return
 
         if self.config.is_backward_kernel_dump:
-            self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
-            self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
             try:
+                self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
+                self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
                 output = module.forward(*self.forward_args, **self.forward_kwargs)
-            except Exception:
+            except Exception as e:
+                if isinstance(e, MsprobeException):
+                    logger.warning(str(e))
                 self._print_unsupported_log(name)
                 self.enable_kernel_dump = False
                 return
@@ -554,9 +574,17 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
         self.stop_kernel_dump()
         logger.info(f"The kernel data of {name} is dumped successfully.")
 
-    @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
+    @recursion_depth_decorator(
+        "KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor",
+        max_depth=Const.DUMP_MAX_DEPTH
+    )
     def clone_and_detach_tensor(self, input_params):
         if isinstance(input_params, torch.Tensor):
+            if is_float8_tensor(input_params):
+                raise MsprobeException(
+                    MsprobeException.UNSUPPORTED_TYPE_ERROR,
+                    f"L2 backward dump does not support float8 type."
+                )
             if input_params.requires_grad:
                 return input_params.clone().detach().requires_grad_()
             return input_params.clone()
@@ -571,6 +599,8 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
 
     def analyze_single_element(self, element, suffix_stack):
         if isinstance(element, torch.Tensor):
+            if is_float8_tensor(element):
+                return {}
             if not self.is_found_output_tensor:
                 if element.requires_grad:
                     self.forward_output_tensor = element
msprobe/core/data_dump/json_writer.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,12 +16,14 @@
 import csv
 import os
 import copy
-import
+import threading
 
 from msprobe.core.common.const import Const, FileCheckConst
 from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
 from msprobe.core.common.log import logger
-from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.decorator import recursion_depth_decorator
+
+lock = threading.Lock()
 
 
 class DataWriter:
@@ -90,28 +92,32 @@ class DataWriter:
             self.write_json()
 
     def update_data(self, new_data):
-        if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
-            logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
-            return
-        dump_data = self.cache_data.get(Const.DATA)
-        if not isinstance(dump_data, dict):
-            logger.warning(f"The dump data({dump_data}) should be a dict.")
-            return
-
-        key = next(iter(new_data.keys()))
-        if key in dump_data:
-            dump_data.get(key).update(new_data.get(key))
-        else:
-            dump_data.update(new_data)
+        with lock:
+            if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
+                logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
+                return
+            dump_data = self.cache_data.get(Const.DATA)
+            if not isinstance(dump_data, dict):
+                logger.warning(f"The dump data({dump_data}) should be a dict.")
+                return
+
+            key = next(iter(new_data.keys()))
+            if key in dump_data:
+                dump_data.get(key).update(new_data.get(key))
+            else:
+                dump_data.update(new_data)
 
     def update_stack(self, new_data):
-        self.cache_stack.update(new_data)
+        with lock:
+            self.cache_stack.update(new_data)
 
     def update_construct(self, new_data):
-        self.cache_construct.update(new_data)
+        with lock:
+            self.cache_construct.update(new_data)
 
     def update_debug(self, new_data):
-        self.cache_debug['data'].update(new_data)
+        with lock:
+            self.cache_debug['data'].update(new_data)
 
     def write_data_json(self, file_path):
         logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
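The `with lock:` wrappers serialize every read-modify-write on the shared caches behind one module-level `threading.Lock`, which matters once multiple dump threads feed the same `DataWriter`. The pattern in isolation (names here are illustrative):

```python
import threading

lock = threading.Lock()
cache = {}

def update(new_data: dict):
    # One lock guards the whole read-modify-write, so two threads cannot
    # interleave between the key lookup and the nested update.
    with lock:
        key = next(iter(new_data))
        cache.setdefault(key, {}).update(new_data[key])
```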
@@ -127,22 +133,21 @@ class DataWriter:
         save_json(file_path, self.cache_debug, indent=1)
 
     def write_json(self):
-        if self.cache_data:
-            self.write_data_json(self.dump_file_path)
-        if self.cache_stack:
-            self.write_stack_info_json(self.stack_file_path)
-        if self.cache_construct:
-            self.write_construct_info_json(self.construct_file_path)
-        if self.cache_debug:
-            self.write_debug_info_json(self.debug_file_path)
+        with lock:
+            if self.cache_data:
+                self.write_data_json(self.dump_file_path)
+            if self.cache_stack:
+                self.write_stack_info_json(self.stack_file_path)
+            if self.cache_construct:
+                self.write_construct_info_json(self.construct_file_path)
+            if self.cache_debug:
+                self.write_debug_info_json(self.debug_file_path)
 
     def fill_stack_tensor_data(self):
         self.process_stat_data_recursive(self.cache_data)
 
-    def process_stat_data_recursive(self, data, depth=0):
-        if depth > Const.MAX_DEPTH:
-            logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
-            raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
+    @recursion_depth_decorator("AsyncDump: DataWriter.process_stat_data_recursive", max_depth=Const.DUMP_MAX_DEPTH)
+    def process_stat_data_recursive(self, data):
         if isinstance(data, dict):
             if "tensor_stat" in data.keys():
                 tensor_stat = data["tensor_stat"]
@@ -150,14 +155,12 @@ class DataWriter:
                     logger.warning("Some bad data in async dump")
                 else:
                     tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
-                    if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
-                        tensor_stat_data = tensor_stat_data.cpu()
                     for index, stat in zip(tensor_stat_index, tensor_stat_data):
                         data.update({index: stat.item()})
                     del data["tensor_stat"]
             else:
                 for key in data.keys():
-                    self.process_stat_data_recursive(data[key], depth + 1)
+                    self.process_stat_data_recursive(data[key])
         elif isinstance(data, (list, tuple)):
             for i in data:
-                self.process_stat_data_recursive(i, depth + 1)
+                self.process_stat_data_recursive(i)
msprobe/core/grad_probe/grad_compare.py
CHANGED
@@ -112,7 +112,7 @@ class GradComparator:
             result.append([key] + value)
         result_csv_path = os.path.join(output_dir, "similarities.csv")
         if os.path.exists(result_csv_path):
-            logger.warning(f"{result_csv_path} will be
+            logger.warning(f"{result_csv_path} will be deleted")
             remove_path(result_csv_path)
         write_csv(result, result_csv_path)
msprobe/core/overflow_check/abnormal_scene.py
CHANGED
@@ -20,6 +20,7 @@ import numpy as np
 from msprobe.core.overflow_check.api_info import APIInfo
 from msprobe.core.overflow_check.level import OverflowLevel
 from msprobe.core.overflow_check.utils import has_nan_inf
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 
 class AnomalyScene:
@@ -35,6 +36,7 @@ class AnomalyScene:
         raise NotImplementedError
 
     @staticmethod
+    @recursion_depth_decorator("AbnormalScene: AnomalyScene._has_anomaly")
     def _has_anomaly(data: Union[Dict, Any]) -> bool:
         """Check whether the tensor contains anomalous values"""
         if isinstance(data, dict):
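`recursion_depth_decorator` itself lives in the new `msprobe/core/common/decorator.py` (+50 lines in the file list), whose body is not shown in this diff. The general shape of such a guard is likely something like the following sketch (my reconstruction under that assumption, not the actual implementation):

```python
import functools

def recursion_depth_decorator(context_msg, max_depth=50):
    # Count live nesting per decorated function and fail fast once max_depth
    # is exceeded, instead of hitting CPython's own recursion limit.
    def wrapper(func):
        depth = 0

        @functools.wraps(func)
        def inner(*args, **kwargs):
            nonlocal depth
            depth += 1
            try:
                if depth > max_depth:
                    raise RecursionError(f"{context_msg}: max depth {max_depth} exceeded")
                return func(*args, **kwargs)
            finally:
                depth -= 1
        return inner
    return wrapper
```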
msprobe/docs/01.installation.md
CHANGED
@@ -16,6 +16,7 @@ pip install mindstudio-probe
 
 |Version|Release Date|Supported PyTorch Versions|Supported MindSpore Versions|Download Link|Checksum|
 |:--:|:--:|:--:|:--:|:--:|:--:|
+|1.2.2|2025.3.03|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|961411bb460d327ea51d6ca4d0c8e8c5565f07c0852d7b8592b781ca35b87212|
 |1.2.1|2025.2.07|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl)|b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b|
 |1.2.0|2025.1.13|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl)|1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc|
 |1.1.1|2024.12.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl)|577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735|
@@ -51,7 +52,7 @@ pip install ./mindstudio_probe*.whl
 
 |Parameter|Description|Required|
 |--|--|:--:|
-|--include-mod|Specifies an optional module. The only accepted value is `adump`, which bundles the adump module into the whl package at build time. By default the parameter is unset and the base package is built.<br>• The adump module serves L2-level dump in MindSpore static-graph scenarios.<br>• Only MindSpore 2.5.0 and later support the adump module.<br>• When installing from source, the build environment must provide GCC 7 or later and CMake 3.14 or later.<br>• The generated whl package can only be used with the Python version and processor architecture it was built with.|No|
+|--include-mod|Specifies an optional module. The only accepted value is `adump`, which bundles the adump module into the whl package at build time. By default the parameter is unset and the base package is built.<br>• The adump module serves L2-level dump in MindSpore static-graph scenarios.<br>• Only MindSpore 2.5.0 and later support the adump module.<br>• When installing from source, the build environment must provide GCC 7.5 or later and CMake 3.14 or later.<br>• The generated whl package can only be used with the Python version and processor architecture it was built with.|No|
 
 # Feature Change Notes