mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -56,7 +56,7 @@ class DataProcessorFactory:
|
|
|
56
56
|
FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
|
|
57
57
|
KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
|
|
58
58
|
)
|
|
59
|
-
from msprobe.pytorch.module_processer import ModuleProcesser
|
|
59
|
+
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
60
60
|
cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
|
|
61
61
|
cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
|
|
62
62
|
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
|
|
@@ -67,10 +67,12 @@ class DataProcessorFactory:
|
|
|
67
67
|
from msprobe.core.data_dump.data_processor.mindspore_processor import (
|
|
68
68
|
StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
|
|
69
69
|
TensorDataProcessor as MindsporeTensorDataProcessor,
|
|
70
|
-
OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
|
|
70
|
+
OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
|
|
71
|
+
KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
|
|
71
72
|
)
|
|
72
73
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
73
74
|
cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
|
|
74
75
|
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
|
|
75
76
|
cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
|
|
77
|
+
cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
|
|
76
78
|
cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2024-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import zlib
|
|
17
17
|
|
|
18
18
|
import mindspore as ms
|
|
19
|
-
from mindspore import mint, ops
|
|
19
|
+
from mindspore import mint, ops, hal
|
|
20
20
|
from mindspore._c_expression.typing import Number
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
@@ -28,6 +28,12 @@ from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_
|
|
|
28
28
|
from msprobe.mindspore.common.log import logger
|
|
29
29
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
30
30
|
|
|
31
|
+
has_adump = True
|
|
32
|
+
try:
|
|
33
|
+
from msprobe.lib import _msprobe_c
|
|
34
|
+
except ImportError:
|
|
35
|
+
has_adump = False
|
|
36
|
+
|
|
31
37
|
|
|
32
38
|
class MindsporeDataProcessor(BaseDataProcessor):
|
|
33
39
|
mindspore_special_type = tuple([ms.Tensor, Number])
|
|
@@ -37,11 +43,12 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
37
43
|
self.mindspore_object_key = {
|
|
38
44
|
"dtype": self.analyze_dtype_in_kwargs
|
|
39
45
|
}
|
|
46
|
+
self._async_dump_cache = {}
|
|
40
47
|
|
|
41
48
|
@staticmethod
|
|
42
49
|
def get_md5_for_tensor(x):
|
|
43
50
|
x = convert_bf16_to_fp32(x)
|
|
44
|
-
tensor_bytes = x.
|
|
51
|
+
tensor_bytes = x.asnumpy().tobytes()
|
|
45
52
|
crc32_hash = zlib.crc32(tensor_bytes)
|
|
46
53
|
return f"{crc32_hash:08x}"
|
|
47
54
|
|
|
@@ -49,22 +56,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
49
56
|
def analyze_dtype_in_kwargs(element):
|
|
50
57
|
return {"type": "mindspore.dtype", "value": str(element)}
|
|
51
58
|
|
|
52
|
-
@
|
|
53
|
-
def
|
|
54
|
-
return super().get_special_types() + cls.mindspore_special_type
|
|
55
|
-
|
|
56
|
-
def get_stat_info(self, data):
|
|
59
|
+
@staticmethod
|
|
60
|
+
def get_stat_info_sync(data):
|
|
57
61
|
tensor_stat = TensorStatInfo()
|
|
58
|
-
if data.
|
|
59
|
-
|
|
60
|
-
elif data.dtype == ms.bool_:
|
|
61
|
-
data_np = data.contiguous().asnumpy()
|
|
62
|
+
if data.dtype == ms.bool_:
|
|
63
|
+
data_np = data.asnumpy()
|
|
62
64
|
tensor_stat.max = np.max(data_np).item()
|
|
63
65
|
tensor_stat.min = np.min(data_np).item()
|
|
64
66
|
elif not data.shape:
|
|
65
67
|
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
|
|
66
68
|
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
67
|
-
data_abs = np.abs(data.
|
|
69
|
+
data_abs = np.abs(data.asnumpy())
|
|
68
70
|
tensor_stat.max = np.max(data_abs).item()
|
|
69
71
|
tensor_stat.min = np.min(data_abs).item()
|
|
70
72
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
@@ -87,6 +89,47 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
87
89
|
api_register.norm_inner_op_set_hook_func()
|
|
88
90
|
return tensor_stat
|
|
89
91
|
|
|
92
|
+
@staticmethod
|
|
93
|
+
def get_stat_info_async(data):
|
|
94
|
+
tensor_stat = TensorStatInfo()
|
|
95
|
+
stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
|
|
96
|
+
if data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
97
|
+
logger.warning("Async dump do not support complex data!")
|
|
98
|
+
return tensor_stat
|
|
99
|
+
elif data.dtype == ms.bool_:
|
|
100
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
|
|
101
|
+
elif not data.shape:
|
|
102
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
|
|
103
|
+
else:
|
|
104
|
+
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
105
|
+
data = data.to(ms.float32)
|
|
106
|
+
api_register.norm_inner_op_set_ori_func()
|
|
107
|
+
get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
|
|
108
|
+
get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
|
|
109
|
+
get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
|
|
110
|
+
if hasattr(mint, "norm"):
|
|
111
|
+
get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
|
|
112
|
+
else:
|
|
113
|
+
get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
|
|
114
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
|
|
115
|
+
[get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
|
|
116
|
+
api_register.norm_inner_op_set_hook_func()
|
|
117
|
+
return tensor_stat
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def get_special_types(cls):
|
|
121
|
+
return super().get_special_types() + cls.mindspore_special_type
|
|
122
|
+
|
|
123
|
+
def get_stat_info(self, data):
|
|
124
|
+
tensor_stat = TensorStatInfo()
|
|
125
|
+
if data.numel() == 0:
|
|
126
|
+
return tensor_stat
|
|
127
|
+
else:
|
|
128
|
+
if self.config.async_dump:
|
|
129
|
+
return MindsporeDataProcessor.get_stat_info_async(data)
|
|
130
|
+
else:
|
|
131
|
+
return MindsporeDataProcessor.get_stat_info_sync(data)
|
|
132
|
+
|
|
90
133
|
def analyze_single_element(self, element, suffix_stack):
|
|
91
134
|
if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
|
|
92
135
|
return self.mindspore_object_key[suffix_stack[-1]](element)
|
|
@@ -107,13 +150,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
107
150
|
tensor_json = {
|
|
108
151
|
'type': 'mindspore.Tensor',
|
|
109
152
|
'dtype': str(tensor.dtype),
|
|
110
|
-
'shape': tensor.shape
|
|
111
|
-
'Max': self.transfer_type(tensor_stat.max),
|
|
112
|
-
'Min': self.transfer_type(tensor_stat.min),
|
|
113
|
-
'Mean': self.transfer_type(tensor_stat.mean),
|
|
114
|
-
'Norm': self.transfer_type(tensor_stat.norm),
|
|
153
|
+
'shape': tensor.shape
|
|
115
154
|
}
|
|
116
|
-
|
|
155
|
+
|
|
156
|
+
if tensor_stat.stack_tensor_stat is None:
|
|
157
|
+
tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
|
|
158
|
+
tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
|
|
159
|
+
tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
|
|
160
|
+
tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
|
|
161
|
+
else:
|
|
162
|
+
tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
|
|
163
|
+
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
117
164
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
118
165
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
119
166
|
return tensor_json
|
|
@@ -124,11 +171,19 @@ class StatisticsDataProcessor(MindsporeDataProcessor):
|
|
|
124
171
|
|
|
125
172
|
|
|
126
173
|
class TensorDataProcessor(MindsporeDataProcessor):
|
|
174
|
+
def dump_async_data(self):
|
|
175
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
176
|
+
save_tensor_as_npy(tensor, file_path)
|
|
177
|
+
self._async_dump_cache.clear()
|
|
178
|
+
|
|
127
179
|
def _analyze_tensor(self, tensor, suffix):
|
|
128
180
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
129
181
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
130
182
|
single_arg.update({"data_name": dump_data_name})
|
|
131
|
-
|
|
183
|
+
if self.config.async_dump:
|
|
184
|
+
self._async_dump_cache[file_path] = tensor.copy()
|
|
185
|
+
else:
|
|
186
|
+
save_tensor_as_npy(tensor, file_path)
|
|
132
187
|
return single_arg
|
|
133
188
|
|
|
134
189
|
|
|
@@ -138,6 +193,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
138
193
|
def __init__(self, config, data_writer):
|
|
139
194
|
super().__init__(config, data_writer)
|
|
140
195
|
self.has_overflow = False
|
|
196
|
+
self.cached_api_info = {}
|
|
141
197
|
self.cached_tensors_and_file_paths = {}
|
|
142
198
|
self.real_overflow_nums = 0
|
|
143
199
|
self.overflow_nums = config.overflow_nums
|
|
@@ -150,6 +206,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
150
206
|
return True
|
|
151
207
|
return False
|
|
152
208
|
|
|
209
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
210
|
+
self.has_overflow = False
|
|
211
|
+
self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
|
|
212
|
+
return None
|
|
213
|
+
|
|
214
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
215
|
+
api_info_struct = super().analyze_forward_output(name, module, module_input_output)
|
|
216
|
+
if name in self.cached_api_info and name in api_info_struct:
|
|
217
|
+
self.cached_api_info[name].update(api_info_struct[name])
|
|
218
|
+
elif name in api_info_struct:
|
|
219
|
+
self.cached_api_info = api_info_struct
|
|
220
|
+
self.maybe_save_overflow_data()
|
|
221
|
+
return self.cached_api_info if self.has_overflow else None
|
|
222
|
+
|
|
153
223
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
154
224
|
self.has_overflow = False
|
|
155
225
|
api_info_struct = super().analyze_forward(name, module, module_input_output)
|
|
@@ -161,6 +231,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
161
231
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
162
232
|
self.maybe_save_overflow_data()
|
|
163
233
|
return api_info_struct if self.has_overflow else None
|
|
234
|
+
|
|
235
|
+
def analyze_params(self, name, param_name, grad):
|
|
236
|
+
self.has_overflow = False
|
|
237
|
+
api_info_struct = super().analyze_params(name, param_name, grad)
|
|
238
|
+
self.maybe_save_overflow_data()
|
|
239
|
+
return api_info_struct if self.has_overflow else None
|
|
164
240
|
|
|
165
241
|
def maybe_save_overflow_data(self):
|
|
166
242
|
if self.has_overflow:
|
|
@@ -190,3 +266,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
190
266
|
self._analyze_maybe_overflow_tensor(single_arg)
|
|
191
267
|
single_arg.update({"data_name": dump_data_name})
|
|
192
268
|
return single_arg
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class KernelDumpDataProcessor(MindsporeDataProcessor):
|
|
272
|
+
def __init__(self, config, data_writer):
|
|
273
|
+
super().__init__(config, data_writer)
|
|
274
|
+
self.enable_kernel_dump = True
|
|
275
|
+
|
|
276
|
+
@staticmethod
|
|
277
|
+
def start_kernel_dump(config_path):
|
|
278
|
+
hal.synchronize()
|
|
279
|
+
_msprobe_c.init_dump()
|
|
280
|
+
_msprobe_c.set_dump(config_path)
|
|
281
|
+
hal.synchronize()
|
|
282
|
+
|
|
283
|
+
@staticmethod
|
|
284
|
+
def stop_kernel_dump():
|
|
285
|
+
hal.synchronize()
|
|
286
|
+
_msprobe_c.finalize_dump()
|
|
287
|
+
hal.synchronize()
|
|
288
|
+
|
|
289
|
+
@staticmethod
|
|
290
|
+
def _print_unsupported_log(api_name):
|
|
291
|
+
logger.warning(f"The kernel dump does not support the {api_name} API.")
|
|
292
|
+
|
|
293
|
+
def analyze_forward_input(self, name, module, module_input_output):
|
|
294
|
+
if not self.enable_kernel_dump:
|
|
295
|
+
return
|
|
296
|
+
if not has_adump:
|
|
297
|
+
logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
|
|
298
|
+
self.enable_kernel_dump = False
|
|
299
|
+
return
|
|
300
|
+
self.start_kernel_dump(self.config.kernel_config_path)
|
|
301
|
+
|
|
302
|
+
def analyze_forward_output(self, name, module, module_input_output):
|
|
303
|
+
if not self.enable_kernel_dump:
|
|
304
|
+
return
|
|
305
|
+
self.enable_kernel_dump = False
|
|
306
|
+
self.stop_kernel_dump()
|
|
307
|
+
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
308
|
+
|
|
309
|
+
def analyze_backward_input(self, name, module, module_input_output):
|
|
310
|
+
if not self.enable_kernel_dump:
|
|
311
|
+
return
|
|
312
|
+
if not has_adump:
|
|
313
|
+
logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
|
|
314
|
+
self.enable_kernel_dump = False
|
|
315
|
+
return
|
|
316
|
+
self.start_kernel_dump(self.config.kernel_config_path)
|
|
317
|
+
|
|
318
|
+
def analyze_backward(self, name, module, module_input_output):
|
|
319
|
+
if not self.enable_kernel_dump:
|
|
320
|
+
return
|
|
321
|
+
self.enable_kernel_dump = False
|
|
322
|
+
self.stop_kernel_dump()
|
|
323
|
+
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
324
|
+
|
|
325
|
+
def reset_status(self):
|
|
326
|
+
self.enable_kernel_dump = True
|
|
@@ -54,6 +54,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
54
54
|
"device": self.analyze_device_in_kwargs,
|
|
55
55
|
"dtype": self.analyze_dtype_in_kwargs
|
|
56
56
|
}
|
|
57
|
+
self._async_dump_cache = {}
|
|
57
58
|
|
|
58
59
|
@staticmethod
|
|
59
60
|
def get_md5_for_tensor(x):
|
|
@@ -82,49 +83,80 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
82
83
|
return {"type": "torch.dtype", "value": str(element)}
|
|
83
84
|
|
|
84
85
|
@staticmethod
|
|
85
|
-
def
|
|
86
|
+
def get_stat_info_async(data):
|
|
86
87
|
tensor_stat = TensorStatInfo()
|
|
87
|
-
if data
|
|
88
|
-
|
|
89
|
-
data_clone = data.detach()
|
|
90
|
-
if data_clone.numel() == 0:
|
|
88
|
+
if torch.is_complex(data):
|
|
89
|
+
logger.warning("Async dump do not support complex data!")
|
|
91
90
|
return tensor_stat
|
|
92
|
-
elif
|
|
93
|
-
tensor_stat.
|
|
94
|
-
|
|
95
|
-
elif not
|
|
96
|
-
tensor_stat.
|
|
97
|
-
|
|
98
|
-
|
|
91
|
+
elif data.dtype == torch.bool:
|
|
92
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
|
|
93
|
+
[torch.any(data), torch.all(data)]))
|
|
94
|
+
elif not data.shape:
|
|
95
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
|
|
96
|
+
else:
|
|
97
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
98
|
+
data = data.float()
|
|
99
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
|
|
100
|
+
torch.max(data),
|
|
101
|
+
torch.min(data),
|
|
102
|
+
torch.mean(data),
|
|
103
|
+
torch.norm(data)
|
|
104
|
+
]))
|
|
105
|
+
return tensor_stat
|
|
106
|
+
|
|
107
|
+
@staticmethod
|
|
108
|
+
def get_stat_info_sync(data):
|
|
109
|
+
tensor_stat = TensorStatInfo()
|
|
110
|
+
if torch.is_complex(data):
|
|
111
|
+
data_np = data.cpu().numpy()
|
|
99
112
|
data_abs = np.abs(data_np)
|
|
100
113
|
tensor_stat.max = np.max(data_abs).item()
|
|
101
114
|
tensor_stat.min = np.min(data_abs).item()
|
|
102
115
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
116
|
+
elif data.dtype == torch.bool:
|
|
117
|
+
tensor_stat.max = torch.any(data).item()
|
|
118
|
+
tensor_stat.min = torch.all(data).item()
|
|
119
|
+
elif not data.shape:
|
|
120
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
|
|
103
121
|
else:
|
|
104
|
-
if not
|
|
105
|
-
|
|
106
|
-
tensor_stat.max = torch.
|
|
107
|
-
tensor_stat.min = torch.
|
|
108
|
-
tensor_stat.mean = torch.
|
|
109
|
-
tensor_stat.norm = torch.
|
|
122
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
123
|
+
data = data.float()
|
|
124
|
+
tensor_stat.max = torch.max(data).item()
|
|
125
|
+
tensor_stat.min = torch.min(data).item()
|
|
126
|
+
tensor_stat.mean = torch.mean(data).item()
|
|
127
|
+
tensor_stat.norm = torch.norm(data).item()
|
|
110
128
|
return tensor_stat
|
|
111
129
|
|
|
130
|
+
@staticmethod
|
|
131
|
+
def get_stat_info(data, async_dump=False):
|
|
132
|
+
tensor_stat = TensorStatInfo()
|
|
133
|
+
if data.is_meta:
|
|
134
|
+
return tensor_stat
|
|
135
|
+
data_clone = data.detach()
|
|
136
|
+
if data_clone.numel() == 0:
|
|
137
|
+
return tensor_stat
|
|
138
|
+
else:
|
|
139
|
+
if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
|
|
140
|
+
return PytorchDataProcessor.get_stat_info_sync(data_clone)
|
|
141
|
+
else:
|
|
142
|
+
return PytorchDataProcessor.get_stat_info_async(data_clone)
|
|
143
|
+
|
|
112
144
|
@staticmethod
|
|
113
145
|
def handle_tensor_extremum_nan_inf(tensor, operator):
|
|
114
146
|
data_clone = tensor.detach()
|
|
115
|
-
data_nan = torch.
|
|
116
|
-
if int(torch.
|
|
147
|
+
data_nan = torch.isnan(data_clone)
|
|
148
|
+
if int(torch.sum(data_nan)) == data_clone.numel():
|
|
117
149
|
return float('nan')
|
|
118
150
|
|
|
119
|
-
finite_mask = torch.
|
|
120
|
-
if int(torch.
|
|
121
|
-
finite_values =
|
|
122
|
-
return torch.
|
|
123
|
-
torch.
|
|
151
|
+
finite_mask = torch.isfinite(data_clone)
|
|
152
|
+
if int(torch.sum(finite_mask)) > 0:
|
|
153
|
+
finite_values = data_clone[finite_mask]
|
|
154
|
+
return torch.max(finite_values).item() if operator == 'max' else \
|
|
155
|
+
torch.min(finite_values).item()
|
|
124
156
|
else:
|
|
125
|
-
data_no_nan =
|
|
126
|
-
return torch.
|
|
127
|
-
torch.
|
|
157
|
+
data_no_nan = data_clone[~data_nan]
|
|
158
|
+
return torch.max(data_no_nan).item() if operator == 'max' else \
|
|
159
|
+
torch.min(data_no_nan).item()
|
|
128
160
|
|
|
129
161
|
@staticmethod
|
|
130
162
|
def process_group_hash(arg):
|
|
@@ -132,6 +164,10 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
132
164
|
group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
|
|
133
165
|
return group_ranks_hash
|
|
134
166
|
|
|
167
|
+
@staticmethod
|
|
168
|
+
def is_distributed_op(module):
|
|
169
|
+
return getattr(module, "op_is_distributed", False)
|
|
170
|
+
|
|
135
171
|
@staticmethod
|
|
136
172
|
def _analyze_torch_size(arg):
|
|
137
173
|
return {"type": "torch.Size", "value": list(arg)}
|
|
@@ -177,26 +213,35 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
177
213
|
return self._analyze_builtin(element)
|
|
178
214
|
return {}
|
|
179
215
|
|
|
216
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
217
|
+
if self.is_distributed_op(module):
|
|
218
|
+
module_input_output.update_output_with_args_and_kwargs()
|
|
219
|
+
return super().analyze_forward_output(name, module, module_input_output)
|
|
220
|
+
|
|
180
221
|
def _analyze_tensor(self, tensor, suffix):
    """Build the dict describing ``tensor``: type, dtype, shape, requires_grad and statistics.

    Statistics come from ``self.get_stat_info(tensor, self.config.async_dump)``:
      * if its ``stack_tensor_stat`` is ``None``, Max/Min/Mean/Norm are stored
        directly, and ``Max_except_inf_nan`` / ``Min_except_inf_nan`` are added
        when the corresponding extremum is inf or NaN;
      * otherwise the stacked value is stored unchanged under ``"tensor_stat"``.

    An MD5 entry is added only when ``summary_mode`` equals ``Const.MD5`` and
    ``async_dump`` is off. ``suffix`` is not used by this implementation.
    """
    stat = self.get_stat_info(tensor, self.config.async_dump)
    tensor_json = {
        'type': 'torch.Tensor',
        'dtype': str(tensor.dtype),
        "shape": tensor.shape,
    }

    stacked = stat.stack_tensor_stat
    if stacked is not None:
        # Stacked statistics are kept as-is; only the grad flag accompanies them.
        tensor_json["requires_grad"] = tensor.requires_grad
        tensor_json["tensor_stat"] = stacked
    else:
        tensor_json["Max"] = stat.max
        tensor_json["Min"] = stat.min
        tensor_json["Mean"] = stat.mean
        tensor_json["Norm"] = stat.norm
        tensor_json["requires_grad"] = tensor.requires_grad
        # Max is handled before Min, matching the order the keys are inserted.
        extrema = (
            ('Max_except_inf_nan', stat.max, "max"),
            ('Min_except_inf_nan', stat.min, "min"),
        )
        for key, extremum, operator in extrema:
            if extremum is None:
                continue
            if np.isinf(extremum) or np.isnan(extremum):
                tensor_json[key] = self.handle_tensor_extremum_nan_inf(tensor, operator)

    md5_wanted = self.config.summary_mode == Const.MD5 and not self.config.async_dump
    if md5_wanted:
        tensor_json[Const.MD5] = self.get_md5_for_tensor(tensor)
    return tensor_json
|
|
@@ -207,12 +252,20 @@ class StatisticsDataProcessor(PytorchDataProcessor):
|
|
|
207
252
|
|
|
208
253
|
|
|
209
254
|
class TensorDataProcessor(PytorchDataProcessor):
    """Data processor that saves tensors with ``save_pt`` in addition to the parent's analysis."""

    def dump_async_data(self):
        """Save every tensor held in ``_async_dump_cache`` to its path, then empty the cache."""
        pending = self._async_dump_cache
        for target_path, cached in pending.items():
            contiguous = cached.contiguous()
            save_pt(contiguous, target_path)
        pending.clear()

    def _analyze_tensor(self, tensor, suffix):
        """Analyze ``tensor`` via the parent class and arrange for its data to be saved.

        The returned dict is the parent's result plus a ``"data_name"`` entry.
        With ``config.async_dump`` set, a detached clone is stored in
        ``_async_dump_cache`` under the file path; otherwise a contiguous,
        detached clone is written immediately with ``save_pt``.
        """
        data_name, target_path = self.get_save_file_path(suffix)
        tensor_info = super()._analyze_tensor(tensor, suffix)
        tensor_info.update({"data_name": data_name})
        if not self.config.async_dump:
            snapshot = tensor.clone().contiguous().detach()
            save_pt(snapshot, target_path)
            return tensor_info
        # Deferred path: keep the clone until dump_async_data() writes it out.
        self._async_dump_cache[target_path] = tensor.clone().detach()
        return tensor_info
|
|
217
270
|
|
|
218
271
|
|
|
@@ -223,7 +276,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
223
276
|
super().__init__(config, data_writer)
|
|
224
277
|
self.has_overflow = False
|
|
225
278
|
self.support_inf_nan = None
|
|
226
|
-
self.
|
|
279
|
+
self.cached_api_info = {}
|
|
227
280
|
self.cached_tensors_and_file_paths = {}
|
|
228
281
|
self.bits_for_overflow = 8
|
|
229
282
|
self.real_overflow_nums = 0
|
|
@@ -237,21 +290,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
237
290
|
return True
|
|
238
291
|
return False
|
|
239
292
|
|
|
240
|
-
def
|
|
293
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Reset the overflow flag and cache the parent's input analysis.

    The parent's result is stored in ``self.cached_api_info``; this method
    itself always returns ``None``.
    """
    self.has_overflow = False
    self._is_support_inf_nan()
    input_info = super().analyze_forward_input(name, module, module_input_output)
    self.cached_api_info = input_info
    return None
|
|
245
298
|
|
|
246
|
-
def
|
|
299
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Merge the parent's output analysis into the cached info and report it on overflow.

    If ``name`` is present in both the cache and the new result, the cached
    entry is updated with the new one; if it is present only in the new
    result, the new result replaces the cache. After ``handle_overflow()``,
    returns ``self.cached_api_info`` when ``self.has_overflow`` is set,
    otherwise ``None``.
    """
    self._is_support_inf_nan()
    output_info = super().analyze_forward_output(name, module, module_input_output)

    # Membership is checked on the cache first, then on the fresh result.
    in_cache = name in self.cached_api_info
    in_output = name in output_info
    if in_output:
        if in_cache:
            self.cached_api_info[name].update(output_info[name])
        else:
            self.cached_api_info = output_info

    self.handle_overflow()
    if not self.has_overflow:
        return None
    return self.cached_api_info
|
|
255
308
|
|
|
256
309
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
257
310
|
self.has_overflow = False
|
|
@@ -266,6 +319,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
266
319
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
267
320
|
self.handle_overflow()
|
|
268
321
|
return api_info_struct if self.has_overflow else None
|
|
322
|
+
|
|
323
|
+
def analyze_params(self, name, param_name, grad):
    """Run the parent's parameter analysis and return it only if an overflow was flagged.

    Resets ``self.has_overflow`` first, calls ``handle_overflow()`` after the
    parent analysis, and returns ``None`` when no overflow is flagged.
    """
    self.has_overflow = False
    self._is_support_inf_nan()
    grad_info = super().analyze_params(name, param_name, grad)
    self.handle_overflow()
    if self.has_overflow:
        return grad_info
    return None
|
|
269
329
|
|
|
270
330
|
def handle_overflow(self):
|
|
271
331
|
if not self.support_inf_nan:
|
|
@@ -340,10 +400,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
340
400
|
)
|
|
341
401
|
return
|
|
342
402
|
|
|
343
|
-
def
|
|
403
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Forward the call's positional and keyword arguments to ``self.checker.pre_forward``.

    Nothing is returned; the checker also receives ``name``, ``module`` and this processor.
    """
    pre_forward = self.checker.pre_forward
    call_args = module_input_output.args
    call_kwargs = module_input_output.kwargs
    pre_forward(name, module, self, call_args, call_kwargs)
|
|
345
405
|
|
|
346
|
-
def
|
|
406
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
347
407
|
new_output, unequal_rows = self.checker.forward(
|
|
348
408
|
name,
|
|
349
409
|
module,
|
|
@@ -388,7 +448,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
388
448
|
def _print_unsupported_log(api_name):
|
|
389
449
|
logger.warning(f"The kernel dump does not support the {api_name} API.")
|
|
390
450
|
|
|
391
|
-
def
|
|
451
|
+
def analyze_forward_input(self, name, module, module_input_output):
|
|
392
452
|
if not self.enable_kernel_dump:
|
|
393
453
|
return
|
|
394
454
|
if is_gpu:
|
|
@@ -413,7 +473,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
413
473
|
return
|
|
414
474
|
self.start_kernel_dump(self.config.kernel_config_path)
|
|
415
475
|
|
|
416
|
-
def
|
|
476
|
+
def analyze_forward_output(self, name, module, module_input_output):
|
|
417
477
|
if not self.enable_kernel_dump:
|
|
418
478
|
return
|
|
419
479
|
if self.config.is_backward_kernel_dump:
|
|
@@ -15,10 +15,12 @@
|
|
|
15
15
|
|
|
16
16
|
import csv
|
|
17
17
|
import os
|
|
18
|
+
import numpy as np
|
|
18
19
|
|
|
19
20
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
20
|
-
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
|
|
21
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class DataWriter:
|
|
@@ -115,3 +117,29 @@ class DataWriter:
|
|
|
115
117
|
self.write_stack_info_json(self.stack_file_path)
|
|
116
118
|
if self.cache_construct:
|
|
117
119
|
self.write_construct_info_json(self.construct_file_path)
|
|
120
|
+
|
|
121
|
+
def fill_stack_tensor_data(self):
    """Run ``process_stat_data_recursive`` over ``self.cache_data``, starting at the default depth."""
    cached = self.cache_data
    self.process_stat_data_recursive(cached)
|
|
123
|
+
|
|
124
|
+
def process_stat_data_recursive(self, data, depth=0):
    """Expand cached ``"tensor_stat"`` entries into individual statistic fields, in place.

    Walks ``data`` (nested dicts, lists and tuples). For every dict holding a
    ``"tensor_stat"`` key, the value is expected to be a pair
    ``(names, values)`` of equal length; each ``values[i].item()`` is stored
    in the dict under ``names[i]`` and the ``"tensor_stat"`` key is removed.
    A malformed entry is logged and dropped without adding statistics.

    Args:
        data: Structure to process; values other than dict/list/tuple are ignored.
        depth: Current recursion depth; callers normally leave the default.

    Raises:
        MsprobeException: With ``RECURSION_LIMIT_ERROR`` when ``depth``
            exceeds ``Const.MAX_DEPTH``.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
        raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)

    if isinstance(data, dict):
        if "tensor_stat" in data:
            tensor_stat = data["tensor_stat"]
            if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
                logger.warning("Some bad data in async dump")
            else:
                tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
                # Bring the statistics to the host once, before reading them one by one.
                if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
                    tensor_stat_data = tensor_stat_data.cpu()
                for index, stat in zip(tensor_stat_index, tensor_stat_data):
                    # Keyed assignment. The previous `{index, stat.item()}` was a
                    # set literal, which dict.update() rejects instead of
                    # storing the statistic under its name.
                    data[index] = stat.item()
            # The raw entry is removed whether or not it could be expanded.
            del data["tensor_stat"]
        else:
            # Recursion only mutates nested containers, never `data` itself,
            # so iterating over the live values view is safe.
            for value in data.values():
                self.process_stat_data_recursive(value, depth + 1)
    elif isinstance(data, (list, tuple)):
        for item in data:
            self.process_stat_data_recursive(item, depth + 1)
|