mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/core/data_dump/data_processor/base.py:

```diff
@@ -17,6 +17,9 @@ import inspect
 import os
 from dataclasses import dataclass, is_dataclass
 from typing import Tuple, Dict, Optional, Any
+from functools import partial
+import copy
+from typing import Union

 import numpy as np

@@ -39,9 +42,8 @@ class ModuleForwardInputsOutputs:
     def output_tuple(self):
         return convert_tuple(self.output)

-    def
-
-        return args
+    def update_output_with_args_and_kwargs(self):
+        self.output = self.args + tuple(self.kwargs.values())


 @dataclass
@@ -77,17 +79,18 @@ class ModuleBackwardOutputs:


 class TensorStatInfo:
-    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
+    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
         self.max = max_val
         self.min = min_val
         self.mean = mean_val
         self.norm = norm_val
+        self.stack_tensor_stat = stack_tensor_stat


 class BaseDataProcessor:
     _recursive_key_stack = []
     special_type = (
-        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
+        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
         bool, int, float, str, slice,
         type(Ellipsis)
     )
@@ -102,6 +105,7 @@ class BaseDataProcessor:
         self.current_iter = 0
         self._return_forward_new_output = False
         self._forward_new_output = None
+        self.save_name = None
         if hasattr(config, "data_mode"):
             self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)

@@ -142,6 +146,37 @@ class BaseDataProcessor:
         else:
             return data

+    @staticmethod
+    def set_value_into_nested_structure(data_structure, indexes, value):
+        '''
+        Args:
+            data_structure: nested data structure
+            indexes: List
+            value: value to be set
+        '''
+        if not indexes:
+            raise ValueError("set_value_into_nested_structure failed: "
+                             "indexes need to be non empty when set value to nested data structure")
+        current_level = data_structure
+        for i, index in enumerate(indexes):
+            valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
+            valid_for_dict = isinstance(current_level, dict) and index in current_level
+            is_last = i == len(indexes) - 1
+            if valid_for_dict or valid_for_list:
+                if is_last:
+                    try:
+                        current_level[index] = value
+                    except Exception as e:
+                        raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
+                else:
+                    try:
+                        current_level = current_level[index]
+                    except Exception as e:
+                        raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
+            else:
+                raise ValueError("set_value_into_nested_structure failed: "
+                                 "invalid data_structure type or invalid index")
+
     @staticmethod
     def _convert_numpy_to_builtin(arg):
         type_mapping = {
```
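For orientation, the new `set_value_into_nested_structure` helper walks a mixed list/dict structure along an index path and overwrites the addressed leaf; the debug-dump backward flow added further down uses it to drop gradient statistics into placeholder slots. A minimal usage sketch with hypothetical data (assumes mindstudio-probe 1.2.2 is installed):

```python
# Hypothetical data; the helper itself ships in msprobe >= 1.2.2.
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor

nested = {"grad_output": [None, {"weight": None}]}   # placeholders recorded in the forward pass
BaseDataProcessor.set_value_into_nested_structure(
    nested, ["grad_output", 1, "weight"], {"Max": 1.0, "Min": -1.0})
print(nested)   # {'grad_output': [None, {'weight': {'Max': 1.0, 'Min': -1.0}}]}
```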
msprobe/core/data_dump/data_processor/base.py (continued):

```diff
@@ -182,8 +217,22 @@ class BaseDataProcessor:
         return single_arg

     @staticmethod
-    def _analyze_numpy(
-
+    def _analyze_numpy(ndarray, numpy_type):
+        ndarray_json = {}
+        ndarray_json.update({'type': 'numpy.ndarray'})
+        ndarray_json.update({'dtype': str(ndarray.dtype)})
+        ndarray_json.update({'shape': ndarray.shape})
+        if ndarray.size > 0:
+            ndarray_json.update({"Max": np.max(ndarray).item()})
+            ndarray_json.update({"Min": np.min(ndarray).item()})
+            ndarray_json.update({"Mean": np.mean(ndarray).item()})
+            ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
+        else:
+            ndarray_json.update({"Max": None})
+            ndarray_json.update({"Min": None})
+            ndarray_json.update({"Mean": None})
+            ndarray_json.update({"Norm": None})
+        return ndarray_json

     @staticmethod
     def _get_allowed_data_mode(data_mode):
@@ -202,7 +251,7 @@ class BaseDataProcessor:
         return cls.special_type

     @classmethod
-    def recursive_apply_transform(cls, args, transform, depth=0):
+    def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
         if depth > Const.MAX_DEPTH:
             logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
@@ -219,20 +268,20 @@ class BaseDataProcessor:
             return cls.apply_transform_dict(args_dict, transform, depth)
         elif isinstance(args, (list, tuple)):
             result_list = cls.apply_transform_list(args, transform, depth)
-            return
+            return result_list
         elif isinstance(args, dict):
             return cls.apply_transform_dict(args, transform, depth)
         elif args is not None:
-            logger.
+            logger.debug(f"Data type {type(args)} is not supported.")
             return None
         else:
             return None
-
+
     @classmethod
     def apply_transform_dict(cls, args, transform, depth):
         result_dict = {}
         for k, arg in args.items():
-            cls._recursive_key_stack.append(
+            cls._recursive_key_stack.append(k)
             result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
             cls._recursive_key_stack.pop()
         return result_dict
@@ -241,11 +290,21 @@ class BaseDataProcessor:
     def apply_transform_list(cls, args, transform, depth):
         result_list = []
         for i, arg in enumerate(args):
-            cls._recursive_key_stack.append(
+            cls._recursive_key_stack.append(i)
             result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
             cls._recursive_key_stack.pop()
         return result_list

+    @classmethod
+    def register_hook_single_element(cls, element, suffix_stack, hook_fn):
+        if cls.is_hookable_element(element):
+            indexes = copy.deepcopy(suffix_stack)
+            wrap_hook_fn = partial(hook_fn, indexes=indexes)
+
+            def real_hook_fn(grad):
+                return wrap_hook_fn(grad)
+            element.register_hook(real_hook_fn)
+
     def if_return_forward_new_output(self):
         return self._return_forward_new_output

```
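`register_hook_single_element` snapshots the traversal's key stack with `copy.deepcopy` and binds it into the gradient hook through `functools.partial`, so each hook still knows the index path it was registered under when the gradient arrives later. A self-contained sketch of that binding pattern, with a hypothetical `Element` class standing in for a tensor that supports `register_hook`:

```python
import copy
from functools import partial

class Element:
    """Hypothetical stand-in for a tensor that supports register_hook."""
    def __init__(self):
        self._hooks = []

    def register_hook(self, fn):
        self._hooks.append(fn)

    def deliver_grad(self, grad):              # stand-in for autograd calling the hooks
        for fn in self._hooks:
            fn(grad)

def hook_fn(grad, indexes):
    print(f"grad for index path {indexes}: {grad}")

suffix_stack = ["input", 0]                    # mutable stack shared across the traversal
element = Element()
wrap_hook_fn = partial(hook_fn, indexes=copy.deepcopy(suffix_stack))  # freeze the path now
element.register_hook(lambda grad: wrap_hook_fn(grad))
suffix_stack.pop()                             # later mutations no longer affect the hook
element.deliver_grad(1.25)                     # -> grad for index path ['input', 0]: 1.25
```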
msprobe/core/data_dump/data_processor/base.py (continued):

```diff
@@ -273,13 +332,10 @@ class BaseDataProcessor:
         """
         return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode

-    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
-        pass
-
     def analyze_element(self, element):
         return self.recursive_apply_transform(element, self.analyze_single_element)

-    def
+    def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         api_info_struct = {}
         # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -291,16 +347,22 @@ class BaseDataProcessor:
             kwargs_info_list = self.analyze_element(module_input_output.kwargs)
             api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list

-
+        return api_info_struct
+
+    def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        api_info_struct = {}
+        # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-            api_info_struct[name] =
+            api_info_struct[name] = {}
             self.api_data_category = Const.OUTPUT
             output_info_list = self.analyze_element(module_input_output.output_tuple)
             api_info_struct[name][Const.OUTPUT] = output_info_list
+
         return api_info_struct

-    def
+    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         api_info_struct = {}
+        # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
             api_info_struct[name] = {}
             self.api_data_category = Const.INPUT
@@ -309,16 +371,18 @@ class BaseDataProcessor:
             self.api_data_category = Const.KWARGS
             kwargs_info_list = self.analyze_element(module_input_output.kwargs)
             api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
-            return api_info_struct

-
-        concat_args = module_input_output.concat_args_and_kwargs()
-        api_info_struct = {}
+        # check whether data_mode contains forward or output
         if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-            api_info_struct[name] = {}
+            api_info_struct[name] = api_info_struct.get(name, {})
             self.api_data_category = Const.OUTPUT
-            output_info_list = self.analyze_element(
+            output_info_list = self.analyze_element(module_input_output.output_tuple)
             api_info_struct[name][Const.OUTPUT] = output_info_list
+
+        if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
+            self.api_data_category = Const.PARAMS
+            api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
+
         return api_info_struct

     def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -359,9 +423,47 @@ class BaseDataProcessor:
         api_info_struct[name][Const.OUTPUT] = output_info_list
         return api_info_struct

+    def analyze_params(self, name, param_name, grad):
+        api_info_struct = {}
+        self.save_name = name + Const.SEP + param_name
+        data_info = self.analyze_element(grad)
+        grad_info_dict = {param_name: [data_info]}
+        api_info_struct[name] = grad_info_dict
+        return api_info_struct
+
     def get_save_file_path(self, suffix):
         file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
-
-
+        if self.save_name is not None:
+            dump_data_name = (self.save_name + file_format)
+            self.save_name = None
+        else:
+            dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
+                              suffix + file_format)
         file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
         return dump_data_name, file_path
+
+    def analyze_element_to_all_none(self, element):
+        return self.recursive_apply_transform(element, lambda element, stack: None)
+
+    def analyze_debug_forward(self, variable, name_with_count):
+        self.current_api_or_module_name = name_with_count
+        self.api_data_category = Const.TENSOR
+        # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
+        data_info = self.analyze_element(variable)
+        return data_info
+
+    def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
+        def hook_fn(grad, indexes):
+            suffix = Const.SEP.join([str(index) for index in indexes])
+            self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
+            grad_data_info = self.analyze_element(grad)
+            self.save_name = None
+            full_index = [grad_name_with_count] + indexes
+            try:
+                self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
+            except (ValueError, IndexError) as e:
+                logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
+                               f"skip current recording, detailed infomation: {e}")
+            return grad
+        wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
+        self.recursive_apply_transform(variable, wrap_register_hook_single_element)
```
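With `save_name` set, `get_save_file_path` now names the dump file after the saved variable rather than the current API, and the debug flow builds names of the form `{name_with_count}.tensor.{indexes}` plus the framework suffix. A small sketch of how such a name is assembled; the dot separator mirrors msprobe's `Const.SEP` (an assumption here) and the variable name is hypothetical:

```python
SEP = "."                                  # assumption: Const.SEP is the dot separator
name_with_count = "norm.0"                 # hypothetical variable name passed to the debug save
indexes = ["input", 1]                     # index path recorded while walking the structure
suffix = SEP.join(str(i) for i in indexes)
save_name = SEP.join([name_with_count, "tensor", suffix])
print(save_name + ".npy")                  # norm.0.tensor.input.1.npy (".pt" for the PyTorch framework)
```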
msprobe/core/data_dump/data_processor/factory.py:

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
 # limitations under the License.

 from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.base import BaseDataProcessor


 class DataProcessorFactory:
@@ -56,21 +57,25 @@ class DataProcessorFactory:
                 FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
                 KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
             )
-            from msprobe.pytorch.module_processer import ModuleProcesser
+            from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
             cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
+            cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
         elif framework == Const.MS_FRAMEWORK:
             from msprobe.core.data_dump.data_processor.mindspore_processor import (
                 StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
                 TensorDataProcessor as MindsporeTensorDataProcessor,
-                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
+                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
+                KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
             )
             from msprobe.mindspore.cell_processor import CellProcessor
             cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
+            cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
+            cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
```
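The factory keys each processor class by a (framework, task) pair; 1.2.2 registers the new structure task (served by the plain `BaseDataProcessor`) for both frameworks, plus a MindSpore kernel-dump processor. A standalone sketch of that registry pattern, as a toy analogue rather than the msprobe implementation:

```python
class ProcessorRegistry:
    """Toy analogue of DataProcessorFactory's (framework, task) -> class mapping."""
    _processors = {}

    @classmethod
    def register_processor(cls, framework, task, processor_cls):
        cls._processors[(framework, task)] = processor_cls

    @classmethod
    def lookup(cls, framework, task):
        return cls._processors[(framework, task)]

class StructureProcessor:
    """Stand-in for BaseDataProcessor, which backs the new structure task."""

ProcessorRegistry.register_processor("pytorch", "structure", StructureProcessor)
print(ProcessorRegistry.lookup("pytorch", "structure").__name__)   # StructureProcessor
```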
msprobe/core/data_dump/data_processor/mindspore_processor.py:

```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright 2024-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,18 +16,24 @@
 import zlib

 import mindspore as ms
-from mindspore import mint, ops
+from mindspore import mint, ops, hal
 from mindspore._c_expression.typing import Number
 import numpy as np

 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
                                                         ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
-from msprobe.core.common.file_utils import path_len_exceeds_limit
+from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
 from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register

+has_adump = True
+try:
+    from msprobe.lib import _msprobe_c
+except ImportError:
+    has_adump = False
+

 class MindsporeDataProcessor(BaseDataProcessor):
     mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,11 +43,12 @@ class MindsporeDataProcessor(BaseDataProcessor):
         self.mindspore_object_key = {
             "dtype": self.analyze_dtype_in_kwargs
         }
+        self._async_dump_cache = {}

     @staticmethod
     def get_md5_for_tensor(x):
         x = convert_bf16_to_fp32(x)
-        tensor_bytes = x.
+        tensor_bytes = x.asnumpy().tobytes()
         crc32_hash = zlib.crc32(tensor_bytes)
         return f"{crc32_hash:08x}"

@@ -49,22 +56,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
     def analyze_dtype_in_kwargs(element):
         return {"type": "mindspore.dtype", "value": str(element)}

-    @
-    def
-        return super().get_special_types() + cls.mindspore_special_type
-
-    def get_stat_info(self, data):
+    @staticmethod
+    def get_stat_info_sync(data):
         tensor_stat = TensorStatInfo()
-        if data.
-
-        elif data.dtype == ms.bool_:
-            data_np = data.contiguous().asnumpy()
+        if data.dtype == ms.bool_:
+            data_np = data.asnumpy()
             tensor_stat.max = np.max(data_np).item()
             tensor_stat.min = np.min(data_np).item()
         elif not data.shape:
             tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
         elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
-            data_abs = np.abs(data.
+            data_abs = np.abs(data.asnumpy())
             tensor_stat.max = np.max(data_abs).item()
             tensor_stat.min = np.min(data_abs).item()
             tensor_stat.mean = np.mean(data_abs).item()
```
msprobe/core/data_dump/data_processor/mindspore_processor.py (continued):

```diff
@@ -87,17 +89,64 @@ class MindsporeDataProcessor(BaseDataProcessor):
             api_register.norm_inner_op_set_hook_func()
         return tensor_stat

+    @staticmethod
+    def get_stat_info_async(data):
+        tensor_stat = TensorStatInfo()
+        stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
+        if data.dtype == ms.complex64 or data.dtype == ms.complex128:
+            logger.warning("Async dump do not support complex data!")
+            return tensor_stat
+        elif data.dtype == ms.bool_:
+            tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
+        elif not data.shape:
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
+        else:
+            if not ops.is_floating_point(data) or data.dtype == ms.float64:
+                data = data.to(ms.float32)
+            api_register.norm_inner_op_set_ori_func()
+            get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
+            get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
+            get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
+            if hasattr(mint, "norm"):
+                get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
+            else:
+                get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
+                [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
+            api_register.norm_inner_op_set_hook_func()
+        return tensor_stat
+
+    @staticmethod
+    def is_hookable_element(element):
+        return hasattr(element, "register_hook") and callable(element.register_hook)
+
+    @classmethod
+    def get_special_types(cls):
+        return super().get_special_types() + cls.mindspore_special_type
+
+    def get_stat_info(self, data):
+        tensor_stat = TensorStatInfo()
+        if data.numel() == 0:
+            return tensor_stat
+        else:
+            if self.config.async_dump:
+                return MindsporeDataProcessor.get_stat_info_async(data)
+            else:
+                return MindsporeDataProcessor.get_stat_info_sync(data)
+
     def analyze_single_element(self, element, suffix_stack):
         if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
             return self.mindspore_object_key[suffix_stack[-1]](element)

         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
         if converted_numpy is not element:
-            return
+            return {"type": numpy_type, "value": converted_numpy}
         if isinstance(element, Number):
             return self.analyze_dtype_in_kwargs(element)
         if isinstance(element, ms.Tensor):
-            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
+            return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
+        if isinstance(element, np.ndarray):
+            return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
             return self._analyze_builtin(element)
         return {}
```
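The async path avoids a host synchronization per statistic: instead of materializing each value with `.item()`, it packs Max/Min/Mean/Norm into one stacked device tensor and stores it in `TensorStatInfo.stack_tensor_stat`, to be resolved when the data is finally written out. A NumPy sketch of the idea, with NumPy standing in for the device tensor (the real code routes through the original mint/ops kernels saved by `api_register`):

```python
import numpy as np

data = np.random.rand(4, 8).astype(np.float32)        # stand-in for a device tensor
keys = ["Max", "Min", "Mean", "Norm"]
stacked = np.stack([data.max(), data.min(), data.mean(), np.linalg.norm(data)])
stack_tensor_stat = (keys, stacked)                    # shape of TensorStatInfo.stack_tensor_stat
print(dict(zip(keys, stacked.tolist())))               # resolved later, in one readback
```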
msprobe/core/data_dump/data_processor/mindspore_processor.py (continued):

```diff
@@ -107,13 +156,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
         tensor_json = {
             'type': 'mindspore.Tensor',
             'dtype': str(tensor.dtype),
-            'shape': tensor.shape
-            'Max': self.transfer_type(tensor_stat.max),
-            'Min': self.transfer_type(tensor_stat.min),
-            'Mean': self.transfer_type(tensor_stat.mean),
-            'Norm': self.transfer_type(tensor_stat.norm),
+            'shape': tensor.shape
         }
-
+
+        if tensor_stat.stack_tensor_stat is None:
+            tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
+            tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
+            tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
+            tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
+        else:
+            tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
+        if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
             tensor_md5 = self.get_md5_for_tensor(tensor)
             tensor_json.update({Const.MD5: tensor_md5})
         return tensor_json
@@ -124,12 +177,27 @@ class StatisticsDataProcessor(MindsporeDataProcessor):


 class TensorDataProcessor(MindsporeDataProcessor):
+    def dump_async_data(self):
+        for file_path, tensor in self._async_dump_cache.items():
+            save_tensor_as_npy(tensor, file_path)
+        self._async_dump_cache.clear()
+
     def _analyze_tensor(self, tensor, suffix):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         single_arg = super()._analyze_tensor(tensor, suffix)
         single_arg.update({"data_name": dump_data_name})
-
+        if self.config.async_dump:
+            self._async_dump_cache[file_path] = tensor.copy()
+        else:
+            save_tensor_as_npy(tensor, file_path)
         return single_arg
+
+    def _analyze_numpy(self, ndarray, suffix):
+        dump_data_name, file_path = self.get_save_file_path(suffix)
+        save_npy(ndarray, file_path)
+        ndarray_json = super()._analyze_numpy(ndarray, suffix)
+        ndarray_json.update({"data_name": dump_data_name})
+        return ndarray_json


 class OverflowCheckDataProcessor(MindsporeDataProcessor):
```
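With `async_dump` enabled, the tensor processor no longer writes each array as it is analyzed; it caches a copy keyed by the target file path and flushes the whole cache in `dump_async_data`. A minimal standalone sketch of that cache-then-flush pattern (hypothetical file name):

```python
import numpy as np

_async_dump_cache = {}

def analyze(tensor, file_path):
    _async_dump_cache[file_path] = tensor.copy()    # defer the disk write

def dump_async_data():
    for file_path, tensor in _async_dump_cache.items():
        np.save(file_path, tensor)                   # flushed later, at a safe point
    _async_dump_cache.clear()

analyze(np.ones((2, 2), dtype=np.float32), "Functional.add.0.forward.input.0.npy")
dump_async_data()
```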
msprobe/core/data_dump/data_processor/mindspore_processor.py (continued):

```diff
@@ -138,6 +206,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
     def __init__(self, config, data_writer):
         super().__init__(config, data_writer)
         self.has_overflow = False
+        self.cached_api_info = {}
         self.cached_tensors_and_file_paths = {}
         self.real_overflow_nums = 0
         self.overflow_nums = config.overflow_nums
@@ -150,6 +219,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
             return True
         return False

+    def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        self.has_overflow = False
+        self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
+        return None
+
+    def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        api_info_struct = super().analyze_forward_output(name, module, module_input_output)
+        if name in self.cached_api_info and name in api_info_struct:
+            self.cached_api_info[name].update(api_info_struct[name])
+        elif name in api_info_struct:
+            self.cached_api_info = api_info_struct
+        self.maybe_save_overflow_data()
+        return self.cached_api_info if self.has_overflow else None
+
     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         self.has_overflow = False
         api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -162,6 +245,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
         self.maybe_save_overflow_data()
         return api_info_struct if self.has_overflow else None

+    def analyze_params(self, name, param_name, grad):
+        self.has_overflow = False
+        api_info_struct = super().analyze_params(name, param_name, grad)
+        self.maybe_save_overflow_data()
+        return api_info_struct if self.has_overflow else None
+
     def maybe_save_overflow_data(self):
         if self.has_overflow:
             for file_path, tensor in self.cached_tensors_and_file_paths.items():
@@ -190,3 +279,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
         self._analyze_maybe_overflow_tensor(single_arg)
         single_arg.update({"data_name": dump_data_name})
         return single_arg
+
+
+class KernelDumpDataProcessor(MindsporeDataProcessor):
+    def __init__(self, config, data_writer):
+        super().__init__(config, data_writer)
+        self.enable_kernel_dump = True
+
+    @staticmethod
+    def start_kernel_dump(config_path):
+        hal.synchronize()
+        _msprobe_c.init_dump()
+        _msprobe_c.set_dump(config_path)
+        hal.synchronize()
+
+    @staticmethod
+    def stop_kernel_dump():
+        hal.synchronize()
+        _msprobe_c.finalize_dump()
+        hal.synchronize()
+
+    @staticmethod
+    def _print_unsupported_log(api_name):
+        logger.warning(f"The kernel dump does not support the {api_name} API.")
+
+    def analyze_forward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_forward_output(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def analyze_backward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_backward(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def reset_status(self):
+        self.enable_kernel_dump = True
```