mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -13,9 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import atexit
|
|
16
17
|
import os
|
|
17
18
|
|
|
18
|
-
from msprobe.core.data_dump.scope import build_scope
|
|
19
|
+
from msprobe.core.data_dump.scope import ScopeFactory
|
|
19
20
|
from msprobe.core.data_dump.json_writer import DataWriter
|
|
20
21
|
from msprobe.core.common.log import logger
|
|
21
22
|
from msprobe.core.common.const import Const
|
|
@@ -27,7 +28,6 @@ def build_data_collector(config):
|
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class DataCollector:
|
|
30
|
-
multi_output_apis = ["_sort_", "npu_flash_attention"]
|
|
31
31
|
tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
|
|
32
32
|
level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
|
|
33
33
|
|
|
@@ -37,13 +37,8 @@ class DataCollector:
|
|
|
37
37
|
self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
|
|
38
38
|
self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
|
|
39
39
|
self.module_count = {}
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
else:
|
|
43
|
-
self.scope = build_scope(None, self.config.scope, self.config.list)
|
|
44
|
-
|
|
45
|
-
def __del__(self):
|
|
46
|
-
self.write_json()
|
|
40
|
+
self.scope = ScopeFactory(self.config).build_scope()
|
|
41
|
+
atexit.register(self.write_json)
|
|
47
42
|
|
|
48
43
|
@property
|
|
49
44
|
def dump_data_dir(self):
|
|
@@ -85,6 +80,10 @@ class DataCollector:
|
|
|
85
80
|
self.data_writer.update_data(data_info)
|
|
86
81
|
|
|
87
82
|
def pre_forward_data_collect(self, name, module, pid, module_input_output):
|
|
83
|
+
if self.config.level == Const.LEVEL_L2 and self.check_scope_and_pid(self.scope, name, pid):
|
|
84
|
+
self.data_processor.analyze_pre_forward(name, module, module_input_output)
|
|
85
|
+
return
|
|
86
|
+
|
|
88
87
|
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
89
88
|
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
90
89
|
self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
|
|
@@ -98,13 +97,14 @@ class DataCollector:
|
|
|
98
97
|
self.update_construct(name)
|
|
99
98
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
100
99
|
return
|
|
100
|
+
if self.config.level == Const.LEVEL_L2:
|
|
101
|
+
self.data_processor.analyze_forward(name, module, module_input_output)
|
|
102
|
+
return
|
|
101
103
|
|
|
102
104
|
if not self.is_inplace(module):
|
|
103
105
|
data_info = self.data_processor.analyze_forward(name, module, module_input_output)
|
|
104
106
|
else:
|
|
105
107
|
data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
|
|
106
|
-
if self.config.level == "L2":
|
|
107
|
-
return
|
|
108
108
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
109
109
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
110
110
|
|
|
@@ -114,6 +114,8 @@ class DataCollector:
|
|
|
114
114
|
return
|
|
115
115
|
|
|
116
116
|
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
117
|
+
if self.config.level == Const.LEVEL_L2:
|
|
118
|
+
return
|
|
117
119
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
118
120
|
|
|
119
121
|
def backward_input_data_collect(self, name, module, pid, module_input_output):
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -15,10 +15,11 @@
|
|
|
15
15
|
|
|
16
16
|
import inspect
|
|
17
17
|
import os
|
|
18
|
-
from dataclasses import dataclass
|
|
18
|
+
from dataclasses import dataclass, is_dataclass
|
|
19
19
|
from typing import Tuple, Dict, Optional, Any
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
|
+
|
|
22
23
|
from msprobe.core.common.const import Const
|
|
23
24
|
from msprobe.core.common.log import logger
|
|
24
25
|
from msprobe.core.common.utils import convert_tuple, CompareException
|
|
@@ -101,6 +102,8 @@ class BaseDataProcessor:
|
|
|
101
102
|
self.current_iter = 0
|
|
102
103
|
self._return_forward_new_output = False
|
|
103
104
|
self._forward_new_output = None
|
|
105
|
+
if hasattr(config, "data_mode"):
|
|
106
|
+
self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
|
|
104
107
|
|
|
105
108
|
@property
|
|
106
109
|
def data_path(self):
|
|
@@ -182,6 +185,18 @@ class BaseDataProcessor:
|
|
|
182
185
|
def _analyze_numpy(value, numpy_type):
|
|
183
186
|
return {"type": numpy_type, "value": value}
|
|
184
187
|
|
|
188
|
+
@staticmethod
|
|
189
|
+
def _get_allowed_data_mode(data_mode):
|
|
190
|
+
if Const.ALL in data_mode:
|
|
191
|
+
allowed_data_mode = [Const.FORWARD, Const.BACKWARD, Const.INPUT, Const.OUTPUT]
|
|
192
|
+
else:
|
|
193
|
+
allowed_data_mode = list(set(data_mode))
|
|
194
|
+
if Const.FORWARD not in allowed_data_mode and Const.BACKWARD not in allowed_data_mode:
|
|
195
|
+
allowed_data_mode += [Const.FORWARD, Const.BACKWARD]
|
|
196
|
+
if Const.INPUT not in allowed_data_mode and Const.OUTPUT not in allowed_data_mode:
|
|
197
|
+
allowed_data_mode += [Const.INPUT, Const.OUTPUT]
|
|
198
|
+
return allowed_data_mode
|
|
199
|
+
|
|
185
200
|
@classmethod
|
|
186
201
|
def get_special_types(cls):
|
|
187
202
|
return cls.special_type
|
|
@@ -194,25 +209,42 @@ class BaseDataProcessor:
|
|
|
194
209
|
if isinstance(args, cls.get_special_types()):
|
|
195
210
|
arg_transform = transform(args, cls._recursive_key_stack)
|
|
196
211
|
return arg_transform
|
|
212
|
+
elif isinstance(args, tuple) and hasattr(args, '_fields'):
|
|
213
|
+
# namedtuple to dict
|
|
214
|
+
args_dict = {field: getattr(args, field) for field in args._fields}
|
|
215
|
+
return cls.apply_transform_dict(args_dict, transform, depth)
|
|
216
|
+
elif is_dataclass(args):
|
|
217
|
+
# dataclass to dict
|
|
218
|
+
args_dict = {field: getattr(args, field) for field in args.__dataclass_fields__}
|
|
219
|
+
return cls.apply_transform_dict(args_dict, transform, depth)
|
|
197
220
|
elif isinstance(args, (list, tuple)):
|
|
198
|
-
result_list = []
|
|
199
|
-
for i, arg in enumerate(args):
|
|
200
|
-
cls._recursive_key_stack.append(str(i))
|
|
201
|
-
result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
|
|
202
|
-
cls._recursive_key_stack.pop()
|
|
221
|
+
result_list = cls.apply_transform_list(args, transform, depth)
|
|
203
222
|
return type(args)(result_list)
|
|
204
223
|
elif isinstance(args, dict):
|
|
205
|
-
result_dict = {}
|
|
206
|
-
for k, arg in args.items():
|
|
207
|
-
cls._recursive_key_stack.append(str(k))
|
|
208
|
-
result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
|
|
209
|
-
cls._recursive_key_stack.pop()
|
|
210
|
-
return result_dict
|
|
224
|
+
return cls.apply_transform_dict(args, transform, depth)
|
|
211
225
|
elif args is not None:
|
|
212
226
|
logger.warning(f"Data type {type(args)} is not supported.")
|
|
213
227
|
return None
|
|
214
228
|
else:
|
|
215
229
|
return None
|
|
230
|
+
|
|
231
|
+
@classmethod
|
|
232
|
+
def apply_transform_dict(cls, args, transform, depth):
|
|
233
|
+
result_dict = {}
|
|
234
|
+
for k, arg in args.items():
|
|
235
|
+
cls._recursive_key_stack.append(str(k))
|
|
236
|
+
result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
|
|
237
|
+
cls._recursive_key_stack.pop()
|
|
238
|
+
return result_dict
|
|
239
|
+
|
|
240
|
+
@classmethod
|
|
241
|
+
def apply_transform_list(cls, args, transform, depth):
|
|
242
|
+
result_list = []
|
|
243
|
+
for i, arg in enumerate(args):
|
|
244
|
+
cls._recursive_key_stack.append(str(i))
|
|
245
|
+
result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
|
|
246
|
+
cls._recursive_key_stack.pop()
|
|
247
|
+
return result_list
|
|
216
248
|
|
|
217
249
|
def if_return_forward_new_output(self):
|
|
218
250
|
return self._return_forward_new_output
|
|
@@ -239,9 +271,7 @@ class BaseDataProcessor:
|
|
|
239
271
|
Return:
|
|
240
272
|
bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
|
|
241
273
|
"""
|
|
242
|
-
return (Const.ALL in self.config.data_mode or
|
|
243
|
-
forward_backward in self.config.data_mode or
|
|
244
|
-
input_output in self.config.data_mode)
|
|
274
|
+
return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
|
|
245
275
|
|
|
246
276
|
def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
247
277
|
pass
|
|
@@ -41,7 +41,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
41
41
|
@staticmethod
|
|
42
42
|
def get_md5_for_tensor(x):
|
|
43
43
|
x = convert_bf16_to_fp32(x)
|
|
44
|
-
tensor_bytes = x.asnumpy().tobytes()
|
|
44
|
+
tensor_bytes = x.contiguous().asnumpy().tobytes()
|
|
45
45
|
crc32_hash = zlib.crc32(tensor_bytes)
|
|
46
46
|
return f"{crc32_hash:08x}"
|
|
47
47
|
|
|
@@ -58,19 +58,19 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
58
58
|
if data.numel() == 0:
|
|
59
59
|
return tensor_stat
|
|
60
60
|
elif data.dtype == ms.bool_:
|
|
61
|
-
data_np = data.asnumpy()
|
|
61
|
+
data_np = data.contiguous().asnumpy()
|
|
62
62
|
tensor_stat.max = np.max(data_np).item()
|
|
63
63
|
tensor_stat.min = np.min(data_np).item()
|
|
64
64
|
elif not data.shape:
|
|
65
65
|
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
|
|
66
66
|
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
67
|
-
data_abs = np.abs(data.asnumpy())
|
|
67
|
+
data_abs = np.abs(data.contiguous().asnumpy())
|
|
68
68
|
tensor_stat.max = np.max(data_abs).item()
|
|
69
69
|
tensor_stat.min = np.min(data_abs).item()
|
|
70
70
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
71
71
|
tensor_stat.norm = np.linalg.norm(data_abs).item()
|
|
72
72
|
else:
|
|
73
|
-
if not ops.is_floating_point(data):
|
|
73
|
+
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
74
74
|
data = data.to(ms.float32)
|
|
75
75
|
api_register.norm_inner_op_set_ori_func()
|
|
76
76
|
get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
|
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import hashlib
|
|
16
17
|
import zlib
|
|
17
18
|
from dataclasses import asdict
|
|
18
19
|
from typing import List
|
|
19
20
|
|
|
20
21
|
import numpy as np
|
|
21
22
|
import torch
|
|
23
|
+
from torch import distributed as dist
|
|
24
|
+
|
|
22
25
|
from msprobe.core.common.const import Const
|
|
23
26
|
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
24
27
|
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.utils import convert_tuple
|
|
25
29
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
26
30
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
27
31
|
from msprobe.pytorch.common.utils import save_pt, load_pt
|
|
28
32
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
33
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
29
34
|
|
|
30
35
|
is_gpu = False
|
|
31
36
|
try:
|
|
@@ -35,7 +40,13 @@ except ImportError:
|
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
38
|
-
pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
|
|
43
|
+
pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
|
|
44
|
+
memory_format = {
|
|
45
|
+
torch.contiguous_format: "contiguous_format",
|
|
46
|
+
torch.channels_last: "channels_last",
|
|
47
|
+
torch.channels_last_3d: "channels_last_3d",
|
|
48
|
+
torch.preserve_format: "preserve_format"
|
|
49
|
+
}
|
|
39
50
|
|
|
40
51
|
def __init__(self, config, data_writer):
|
|
41
52
|
super().__init__(config, data_writer)
|
|
@@ -79,8 +90,8 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
79
90
|
if data_clone.numel() == 0:
|
|
80
91
|
return tensor_stat
|
|
81
92
|
elif data_clone.dtype == torch.bool:
|
|
82
|
-
tensor_stat.max =
|
|
83
|
-
tensor_stat.min =
|
|
93
|
+
tensor_stat.max = torch._C._VariableFunctionsClass.any(data_clone).item()
|
|
94
|
+
tensor_stat.min = torch._C._VariableFunctionsClass.all(data_clone).item()
|
|
84
95
|
elif not data_clone.shape:
|
|
85
96
|
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
|
|
86
97
|
elif torch.is_complex(data_clone):
|
|
@@ -104,20 +115,46 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
104
115
|
data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
|
|
105
116
|
if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
|
|
106
117
|
return float('nan')
|
|
118
|
+
|
|
107
119
|
finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
|
|
108
120
|
if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
|
|
109
|
-
finite_values = data_clone[finite_mask]
|
|
121
|
+
finite_values = getattr(torch._C._TensorBase, "__getitem__")(data_clone, finite_mask)
|
|
110
122
|
return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
|
|
111
123
|
torch._C._VariableFunctionsClass.min(finite_values).item()
|
|
112
124
|
else:
|
|
113
|
-
data_no_nan = data_clone[~data_nan]
|
|
125
|
+
data_no_nan = getattr(torch._C._TensorBase, "__getitem__")(data_clone, ~data_nan)
|
|
114
126
|
return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
|
|
115
127
|
torch._C._VariableFunctionsClass.min(data_no_nan).item()
|
|
116
128
|
|
|
129
|
+
@staticmethod
|
|
130
|
+
def process_group_hash(arg):
|
|
131
|
+
group_ranks = dist.get_process_group_ranks(arg)
|
|
132
|
+
group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
|
|
133
|
+
return group_ranks_hash
|
|
134
|
+
|
|
117
135
|
@staticmethod
|
|
118
136
|
def _analyze_torch_size(arg):
|
|
119
137
|
return {"type": "torch.Size", "value": list(arg)}
|
|
120
138
|
|
|
139
|
+
@staticmethod
|
|
140
|
+
def _analyze_memory_format(arg):
|
|
141
|
+
# 获取内存格式
|
|
142
|
+
format_type = PytorchDataProcessor.memory_format.get(arg)
|
|
143
|
+
|
|
144
|
+
return {"type": "torch.memory_format", "format": format_type}
|
|
145
|
+
|
|
146
|
+
@staticmethod
|
|
147
|
+
def _analyze_process_group(arg):
|
|
148
|
+
group_info = {"type": "torch.ProcessGroup"}
|
|
149
|
+
try:
|
|
150
|
+
group_ranks = dist.get_process_group_ranks(arg)
|
|
151
|
+
group_info.update({"group_ranks": group_ranks})
|
|
152
|
+
group_id = PytorchDataProcessor.process_group_hash(arg)
|
|
153
|
+
group_info.update({"group_id": group_id})
|
|
154
|
+
except Exception as e:
|
|
155
|
+
logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
|
|
156
|
+
return group_info
|
|
157
|
+
|
|
121
158
|
@classmethod
|
|
122
159
|
def get_special_types(cls):
|
|
123
160
|
return super().get_special_types() + cls.pytorch_special_type
|
|
@@ -127,6 +164,10 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
127
164
|
return self.torch_object_key[suffix_stack[-1]](element)
|
|
128
165
|
if isinstance(element, torch.Size):
|
|
129
166
|
return self._analyze_torch_size(element)
|
|
167
|
+
if isinstance(element, torch.memory_format):
|
|
168
|
+
return self._analyze_memory_format(element)
|
|
169
|
+
if isinstance(element, dist.ProcessGroup):
|
|
170
|
+
return self._analyze_process_group(element)
|
|
130
171
|
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
131
172
|
if converted_numpy is not element:
|
|
132
173
|
return self._analyze_numpy(converted_numpy, numpy_type)
|
|
@@ -320,64 +361,120 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
320
361
|
|
|
321
362
|
|
|
322
363
|
class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
323
|
-
forward_init_status = False
|
|
324
|
-
multi_output_apis = ["_sort_", "npu_flash_attention"]
|
|
325
|
-
|
|
326
364
|
def __init__(self, config, data_writer):
|
|
327
365
|
super().__init__(config, data_writer)
|
|
366
|
+
self.enable_kernel_dump = True
|
|
367
|
+
self.is_found_output_tensor = False
|
|
368
|
+
self.is_found_grad_input_tensor = False
|
|
369
|
+
self.forward_args = None
|
|
370
|
+
self.forward_kwargs = None
|
|
371
|
+
self.forward_output_tensor = None
|
|
372
|
+
self.grad_input_tensor = None
|
|
373
|
+
|
|
374
|
+
@staticmethod
|
|
375
|
+
def start_kernel_dump(config_path):
|
|
376
|
+
torch_npu.npu.synchronize()
|
|
377
|
+
torch_npu.npu.init_dump()
|
|
378
|
+
torch_npu.npu.set_dump(config_path)
|
|
379
|
+
torch_npu.npu.synchronize()
|
|
380
|
+
|
|
381
|
+
@staticmethod
|
|
382
|
+
def stop_kernel_dump():
|
|
383
|
+
torch_npu.npu.synchronize()
|
|
384
|
+
torch_npu.npu.finalize_dump()
|
|
385
|
+
torch_npu.npu.synchronize()
|
|
386
|
+
|
|
387
|
+
@staticmethod
|
|
388
|
+
def _print_unsupported_log(api_name):
|
|
389
|
+
logger.warning(f"The kernel dump does not support the {api_name} API.")
|
|
390
|
+
|
|
391
|
+
def analyze_pre_forward(self, name, module, module_input_output):
|
|
392
|
+
if not self.enable_kernel_dump:
|
|
393
|
+
return
|
|
394
|
+
if is_gpu:
|
|
395
|
+
logger.warning("The current environment is not a complete NPU environment, and kernel dump cannot be used.")
|
|
396
|
+
self.enable_kernel_dump = False
|
|
397
|
+
return
|
|
398
|
+
|
|
399
|
+
if self.config.is_backward_kernel_dump:
|
|
400
|
+
self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
|
|
401
|
+
self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
|
|
402
|
+
try:
|
|
403
|
+
output = module.forward(*self.forward_args, **self.forward_kwargs)
|
|
404
|
+
except Exception:
|
|
405
|
+
self._print_unsupported_log(name)
|
|
406
|
+
self.enable_kernel_dump = False
|
|
407
|
+
return
|
|
408
|
+
|
|
409
|
+
self.analyze_element(convert_tuple(output))
|
|
410
|
+
if not self.is_found_output_tensor:
|
|
411
|
+
self._print_unsupported_log(name)
|
|
412
|
+
self.enable_kernel_dump = False
|
|
413
|
+
return
|
|
414
|
+
self.start_kernel_dump(self.config.kernel_config_path)
|
|
328
415
|
|
|
329
416
|
def analyze_forward(self, name, module, module_input_output):
|
|
330
|
-
if self.
|
|
331
|
-
|
|
417
|
+
if not self.enable_kernel_dump:
|
|
418
|
+
return
|
|
419
|
+
if self.config.is_backward_kernel_dump:
|
|
420
|
+
return
|
|
421
|
+
self.enable_kernel_dump = False
|
|
422
|
+
self.stop_kernel_dump()
|
|
423
|
+
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
424
|
+
|
|
425
|
+
def analyze_backward(self, name, module, module_input_output):
    """Run the backward pass under kernel dump.

    Does nothing when ``self.enable_kernel_dump`` is false; otherwise the
    flag is cleared before any further work. ``grad_input`` is scanned
    through ``self.analyze_element``; if no grad-input tensor is recorded
    the API is reported as unsupported. Otherwise the dump is started,
    ``self.forward_output_tensor.backward`` is called with the recorded
    grad input, and the dump is stopped whether or not that call raised an
    ``Exception``.

    Args:
        name: API name used in log messages.
        module: Unused here.
        module_input_output: Carrier exposing ``grad_input``.
    """
    if not self.enable_kernel_dump:
        return
    self.enable_kernel_dump = False

    self.analyze_element(module_input_output.grad_input)
    if not self.is_found_grad_input_tensor:
        self._print_unsupported_log(name)
        return

    self.start_kernel_dump(self.config.kernel_config_path)
    try:
        self.forward_output_tensor.backward(self.grad_input_tensor, retain_graph=True)
    except Exception:
        self._print_unsupported_log(name)
        dumped = False
    else:
        dumped = True

    # The dump is finalized on both the success and the failure path.
    self.stop_kernel_dump()
    if dumped:
        logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
445
|
+
|
|
446
|
+
@recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
def clone_and_detach_tensor(self, input_params):
    """Recursively copy the tensors found in ``input_params``.

    A tensor with ``requires_grad`` set is cloned, detached and marked as
    requiring grad again; any other tensor is only cloned. Tuples, lists
    and dicts are rebuilt as a plain ``tuple`` / ``list`` / ``dict`` with
    their items (dict values) processed recursively. Every other value is
    returned unchanged.

    Args:
        input_params: A tensor, a (nested) tuple/list/dict, or any other
            object.

    Returns:
        The copied structure, or ``input_params`` itself for unhandled
        types.
    """
    if isinstance(input_params, torch.Tensor):
        copied = input_params.clone()
        if input_params.requires_grad:
            copied = copied.detach().requires_grad_()
        return copied
    if isinstance(input_params, tuple):
        return tuple(self.clone_and_detach_tensor(item) for item in input_params)
    if isinstance(input_params, list):
        return [self.clone_and_detach_tensor(item) for item in input_params]
    if isinstance(input_params, dict):
        return {key: self.clone_and_detach_tensor(value) for key, value in input_params.items()}
    return input_params
|
|
356
460
|
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
461
|
+
def analyze_single_element(self, element, suffix_stack):
    """Record the tensors needed to replay a backward pass.

    Non-tensor elements are ignored. Until an output tensor has been
    recorded, a tensor is only considered as the forward output and is
    kept when it has ``requires_grad`` set. After that, the first tensor
    seen is cloned and stored as the grad input.

    Args:
        element: Value to inspect.
        suffix_stack: Unused here.

    Returns:
        Always an empty dict.
    """
    if not isinstance(element, torch.Tensor):
        return {}

    if self.is_found_output_tensor:
        if not self.is_found_grad_input_tensor:
            self.grad_input_tensor = element.clone()
            self.is_found_grad_input_tensor = True
    elif element.requires_grad:
        self.forward_output_tensor = element
        self.is_found_output_tensor = True
    return {}
|
|
362
472
|
|
|
363
|
-
def
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
torch_npu.npu.set_dump(self.config.acl_config)
|
|
372
|
-
torch_npu.npu.synchronize()
|
|
373
|
-
if not self.acl_backward_dump_status(output, grad, name):
|
|
374
|
-
logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
|
|
375
|
-
"you can manually construct a single API backward case for ACL dump.".format(
|
|
376
|
-
name))
|
|
377
|
-
torch_npu.npu.synchronize()
|
|
378
|
-
torch_npu.npu.finalize_dump()
|
|
379
|
-
KernelDumpDataProcessor.forward_init_status = False
|
|
380
|
-
logger.info("Dump %s op file." % name)
|
|
381
|
-
|
|
382
|
-
def op_need_trigger(self, module_name):
|
|
383
|
-
return 'Tensor.__getitem__.' in module_name
|
|
473
|
+
def reset_status(self):
    """Restore the kernel-dump state to its initial values.

    Enables kernel dump again, clears both "found" flags and drops the
    stored forward arguments, forward output tensor and grad-input tensor.
    """
    self.enable_kernel_dump = True
    self.is_found_output_tensor = self.is_found_grad_input_tensor = False
    self.forward_args = self.forward_kwargs = None
    self.forward_output_tensor = self.grad_input_tensor = None
|