mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
- msprobe/README.md +2 -2
- msprobe/core/common/const.py +34 -9
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +14 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -7
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/utils.py +10 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +92 -8
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
- msprobe/core/data_dump/json_writer.py +26 -8
- msprobe/docs/01.installation.md +25 -0
- msprobe/docs/02.config_introduction.md +14 -12
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +34 -15
- msprobe/docs/06.data_dump_MindSpore.md +45 -22
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
- msprobe/docs/19.monitor.md +257 -260
- msprobe/docs/21.visualization_PyTorch.md +10 -0
- msprobe/docs/22.visualization_MindSpore.md +11 -0
- msprobe/docs/27.dump_json_instruction.md +24 -20
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/utils.py +20 -2
- msprobe/mindspore/debugger/debugger_config.py +25 -2
- msprobe/mindspore/debugger/precision_debugger.py +25 -6
- msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/service.py +95 -21
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +71 -0
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +14 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +10 -12
- msprobe/pytorch/monitor/module_hook.py +123 -104
- msprobe/pytorch/monitor/module_metric.py +6 -6
- msprobe/pytorch/monitor/optimizer_collect.py +45 -63
- msprobe/pytorch/monitor/utils.py +8 -43
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +103 -24
- msprobe/visualization/builder/graph_builder.py +31 -5
- msprobe/visualization/builder/msprobe_adapter.py +7 -5
- msprobe/visualization/graph/base_node.py +3 -2
- msprobe/visualization/graph/distributed_analyzer.py +80 -3
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +3 -4
- msprobe/visualization/utils.py +10 -2
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -170,6 +170,16 @@ def gen_op_item(op_data, op_name):
|
|
|
170
170
|
elif op_item.get('type') == 'slice':
|
|
171
171
|
op_item['dtype'] = op_data.get('type')
|
|
172
172
|
op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
|
|
173
|
+
elif op_item.get('type') == 'ellipsis':
|
|
174
|
+
op_item['dtype'] = op_data.get('type')
|
|
175
|
+
op_item['shape'] = '[]'
|
|
176
|
+
for i in params:
|
|
177
|
+
op_item[i] = op_data.get('value')
|
|
178
|
+
elif op_item.get('type') == 'torch.ProcessGroup':
|
|
179
|
+
op_item['dtype'] = op_data.get('type')
|
|
180
|
+
op_item['shape'] = '[]'
|
|
181
|
+
for i in params:
|
|
182
|
+
op_item[i] = str(op_data.get('group_ranks'))
|
|
173
183
|
else:
|
|
174
184
|
op_item['dtype'] = str(type(op_data.get('value')))
|
|
175
185
|
op_item['shape'] = '[]'
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -40,6 +40,7 @@ class DataCollector:
|
|
|
40
40
|
self.scope = ScopeFactory(self.config).build_scope()
|
|
41
41
|
self.backward_module_names = {}
|
|
42
42
|
self.optimizer_status = ""
|
|
43
|
+
self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
|
|
43
44
|
atexit.register(self.write_json)
|
|
44
45
|
|
|
45
46
|
@property
|
|
@@ -54,6 +55,17 @@ class DataCollector:
|
|
|
54
55
|
def check_scope_and_pid(scope, name, pid):
|
|
55
56
|
return (not scope or scope.check(name)) and pid == os.getpid()
|
|
56
57
|
|
|
58
|
+
@staticmethod
|
|
59
|
+
def set_is_recomputable(data_info, is_recompute):
|
|
60
|
+
if data_info and len(data_info) == 1 and is_recompute is not None: # 正常情况下data_info的长度应改为1
|
|
61
|
+
data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
|
|
62
|
+
|
|
63
|
+
def reset_status(self):
|
|
64
|
+
self.optimizer_status = ""
|
|
65
|
+
self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
|
|
66
|
+
self.data_writer.reset_cache()
|
|
67
|
+
self.backward_module_names.clear()
|
|
68
|
+
|
|
57
69
|
def if_return_forward_new_output(self):
|
|
58
70
|
return self.data_processor.if_return_forward_new_output()
|
|
59
71
|
|
|
@@ -77,7 +89,7 @@ class DataCollector:
|
|
|
77
89
|
logger.debug(msg)
|
|
78
90
|
self.data_writer.update_data(data_info)
|
|
79
91
|
|
|
80
|
-
def forward_input_data_collect(self, name, module, pid, module_input_output):
|
|
92
|
+
def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
81
93
|
if self.config.task == Const.FREE_BENCHMARK:
|
|
82
94
|
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
83
95
|
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
@@ -87,37 +99,48 @@ class DataCollector:
|
|
|
87
99
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
88
100
|
return
|
|
89
101
|
|
|
90
|
-
data_info =
|
|
102
|
+
data_info = {}
|
|
103
|
+
if self.config.task != Const.STRUCTURE:
|
|
104
|
+
data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
|
|
105
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
91
106
|
if self.config.level == Const.LEVEL_L2:
|
|
92
107
|
return
|
|
93
108
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
94
109
|
|
|
95
|
-
def forward_output_data_collect(self, name, module, pid, module_input_output):
|
|
110
|
+
def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
96
111
|
self.update_construct(name)
|
|
97
112
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
98
113
|
return
|
|
99
114
|
|
|
100
|
-
data_info =
|
|
115
|
+
data_info = {}
|
|
116
|
+
if self.config.task != Const.STRUCTURE:
|
|
117
|
+
data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
|
|
118
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
101
119
|
if self.config.level == Const.LEVEL_L2:
|
|
102
120
|
return
|
|
103
121
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
104
122
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
105
123
|
|
|
106
|
-
def forward_data_collect(self, name, module, pid, module_input_output):
|
|
124
|
+
def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
107
125
|
self.update_construct(name)
|
|
108
126
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
109
127
|
return
|
|
110
128
|
|
|
111
|
-
data_info =
|
|
129
|
+
data_info = {}
|
|
130
|
+
if self.config.task != Const.STRUCTURE:
|
|
131
|
+
data_info = self.data_processor.analyze_forward(name, module, module_input_output)
|
|
132
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
112
133
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
113
134
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
114
135
|
|
|
115
|
-
def backward_data_collect(self, name, module, pid, module_input_output):
|
|
136
|
+
def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
116
137
|
self.update_construct(name)
|
|
117
138
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
118
139
|
return
|
|
119
140
|
|
|
120
|
-
data_info =
|
|
141
|
+
data_info = {}
|
|
142
|
+
if self.config.task != Const.STRUCTURE:
|
|
143
|
+
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
121
144
|
if self.config.level == Const.LEVEL_L2:
|
|
122
145
|
return
|
|
123
146
|
# 获取执行反向的模块名称
|
|
@@ -127,25 +150,34 @@ class DataCollector:
|
|
|
127
150
|
self.backward_module_names[module_name] = True
|
|
128
151
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
129
152
|
|
|
130
|
-
def backward_input_data_collect(self, name, module, pid, module_input_output):
|
|
153
|
+
def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
131
154
|
self.update_construct(name)
|
|
132
155
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
133
156
|
return
|
|
134
157
|
|
|
135
|
-
data_info =
|
|
158
|
+
data_info = {}
|
|
159
|
+
if self.config.task != Const.STRUCTURE:
|
|
160
|
+
data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
|
|
161
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
136
162
|
self.handle_data(name, data_info)
|
|
137
163
|
|
|
138
|
-
def backward_output_data_collect(self, name, module, pid, module_input_output):
|
|
164
|
+
def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
139
165
|
self.update_construct(name)
|
|
140
166
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
141
167
|
return
|
|
142
168
|
|
|
143
|
-
data_info =
|
|
169
|
+
data_info = {}
|
|
170
|
+
if self.config.task != Const.STRUCTURE:
|
|
171
|
+
data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
|
|
172
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
144
173
|
self.handle_data(name, data_info)
|
|
145
174
|
|
|
146
175
|
def update_construct(self, name):
|
|
147
176
|
if self.config.level not in DataCollector.level_without_construct:
|
|
148
177
|
if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
|
|
178
|
+
if self.optimizer_status_first_start[self.optimizer_status]:
|
|
179
|
+
self.data_writer.update_construct({self.optimizer_status: None})
|
|
180
|
+
self.optimizer_status_first_start[self.optimizer_status] = False
|
|
149
181
|
self.data_writer.update_construct({name: self.optimizer_status})
|
|
150
182
|
else:
|
|
151
183
|
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
@@ -183,3 +215,16 @@ class DataCollector:
|
|
|
183
215
|
|
|
184
216
|
def fill_stack_tensor_data(self):
|
|
185
217
|
self.data_writer.fill_stack_tensor_data()
|
|
218
|
+
|
|
219
|
+
def debug_data_collect_forward(self, variable, name_with_count):
|
|
220
|
+
|
|
221
|
+
data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
|
|
222
|
+
self.data_writer.update_debug({name_with_count: data_info})
|
|
223
|
+
|
|
224
|
+
def debug_data_collect_backward(self, variable, grad_name_with_count):
|
|
225
|
+
# prepare all None nested data structure
|
|
226
|
+
all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
|
|
227
|
+
self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
|
|
228
|
+
|
|
229
|
+
# register tensor backward hook
|
|
230
|
+
self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
|
|
@@ -17,6 +17,9 @@ import inspect
|
|
|
17
17
|
import os
|
|
18
18
|
from dataclasses import dataclass, is_dataclass
|
|
19
19
|
from typing import Tuple, Dict, Optional, Any
|
|
20
|
+
from functools import partial
|
|
21
|
+
import copy
|
|
22
|
+
from typing import Union
|
|
20
23
|
|
|
21
24
|
import numpy as np
|
|
22
25
|
|
|
@@ -87,7 +90,7 @@ class TensorStatInfo:
|
|
|
87
90
|
class BaseDataProcessor:
|
|
88
91
|
_recursive_key_stack = []
|
|
89
92
|
special_type = (
|
|
90
|
-
np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
|
|
93
|
+
np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
|
|
91
94
|
bool, int, float, str, slice,
|
|
92
95
|
type(Ellipsis)
|
|
93
96
|
)
|
|
@@ -143,6 +146,37 @@ class BaseDataProcessor:
|
|
|
143
146
|
else:
|
|
144
147
|
return data
|
|
145
148
|
|
|
149
|
+
@staticmethod
|
|
150
|
+
def set_value_into_nested_structure(data_structure, indexes, value):
|
|
151
|
+
'''
|
|
152
|
+
Args:
|
|
153
|
+
data_structure: nested data structure
|
|
154
|
+
indexes: List
|
|
155
|
+
value: value to be set
|
|
156
|
+
'''
|
|
157
|
+
if not indexes:
|
|
158
|
+
raise ValueError("set_value_into_nested_structure failed: "
|
|
159
|
+
"indexes need to be non empty when set value to nested data structure")
|
|
160
|
+
current_level = data_structure
|
|
161
|
+
for i, index in enumerate(indexes):
|
|
162
|
+
valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
|
|
163
|
+
valid_for_dict = isinstance(current_level, dict) and index in current_level
|
|
164
|
+
is_last = i == len(indexes) - 1
|
|
165
|
+
if valid_for_dict or valid_for_list:
|
|
166
|
+
if is_last:
|
|
167
|
+
try:
|
|
168
|
+
current_level[index] = value
|
|
169
|
+
except Exception as e:
|
|
170
|
+
raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
|
|
171
|
+
else:
|
|
172
|
+
try:
|
|
173
|
+
current_level = current_level[index]
|
|
174
|
+
except Exception as e:
|
|
175
|
+
raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError("set_value_into_nested_structure failed: "
|
|
178
|
+
"invalid data_structure type or invalid index")
|
|
179
|
+
|
|
146
180
|
@staticmethod
|
|
147
181
|
def _convert_numpy_to_builtin(arg):
|
|
148
182
|
type_mapping = {
|
|
@@ -183,8 +217,22 @@ class BaseDataProcessor:
|
|
|
183
217
|
return single_arg
|
|
184
218
|
|
|
185
219
|
@staticmethod
|
|
186
|
-
def _analyze_numpy(
|
|
187
|
-
|
|
220
|
+
def _analyze_numpy(ndarray, numpy_type):
|
|
221
|
+
ndarray_json = {}
|
|
222
|
+
ndarray_json.update({'type': 'numpy.ndarray'})
|
|
223
|
+
ndarray_json.update({'dtype': str(ndarray.dtype)})
|
|
224
|
+
ndarray_json.update({'shape': ndarray.shape})
|
|
225
|
+
if ndarray.size > 0:
|
|
226
|
+
ndarray_json.update({"Max": np.max(ndarray).item()})
|
|
227
|
+
ndarray_json.update({"Min": np.min(ndarray).item()})
|
|
228
|
+
ndarray_json.update({"Mean": np.mean(ndarray).item()})
|
|
229
|
+
ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
|
|
230
|
+
else:
|
|
231
|
+
ndarray_json.update({"Max": None})
|
|
232
|
+
ndarray_json.update({"Min": None})
|
|
233
|
+
ndarray_json.update({"Mean": None})
|
|
234
|
+
ndarray_json.update({"Norm": None})
|
|
235
|
+
return ndarray_json
|
|
188
236
|
|
|
189
237
|
@staticmethod
|
|
190
238
|
def _get_allowed_data_mode(data_mode):
|
|
@@ -203,7 +251,7 @@ class BaseDataProcessor:
|
|
|
203
251
|
return cls.special_type
|
|
204
252
|
|
|
205
253
|
@classmethod
|
|
206
|
-
def recursive_apply_transform(cls, args, transform, depth=0):
|
|
254
|
+
def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
|
|
207
255
|
if depth > Const.MAX_DEPTH:
|
|
208
256
|
logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
|
|
209
257
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
@@ -220,7 +268,7 @@ class BaseDataProcessor:
|
|
|
220
268
|
return cls.apply_transform_dict(args_dict, transform, depth)
|
|
221
269
|
elif isinstance(args, (list, tuple)):
|
|
222
270
|
result_list = cls.apply_transform_list(args, transform, depth)
|
|
223
|
-
return
|
|
271
|
+
return result_list
|
|
224
272
|
elif isinstance(args, dict):
|
|
225
273
|
return cls.apply_transform_dict(args, transform, depth)
|
|
226
274
|
elif args is not None:
|
|
@@ -228,12 +276,12 @@ class BaseDataProcessor:
|
|
|
228
276
|
return None
|
|
229
277
|
else:
|
|
230
278
|
return None
|
|
231
|
-
|
|
279
|
+
|
|
232
280
|
@classmethod
|
|
233
281
|
def apply_transform_dict(cls, args, transform, depth):
|
|
234
282
|
result_dict = {}
|
|
235
283
|
for k, arg in args.items():
|
|
236
|
-
cls._recursive_key_stack.append(
|
|
284
|
+
cls._recursive_key_stack.append(k)
|
|
237
285
|
result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
|
|
238
286
|
cls._recursive_key_stack.pop()
|
|
239
287
|
return result_dict
|
|
@@ -242,11 +290,21 @@ class BaseDataProcessor:
|
|
|
242
290
|
def apply_transform_list(cls, args, transform, depth):
|
|
243
291
|
result_list = []
|
|
244
292
|
for i, arg in enumerate(args):
|
|
245
|
-
cls._recursive_key_stack.append(
|
|
293
|
+
cls._recursive_key_stack.append(i)
|
|
246
294
|
result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
|
|
247
295
|
cls._recursive_key_stack.pop()
|
|
248
296
|
return result_list
|
|
249
297
|
|
|
298
|
+
@classmethod
|
|
299
|
+
def register_hook_single_element(cls, element, suffix_stack, hook_fn):
|
|
300
|
+
if cls.is_hookable_element(element):
|
|
301
|
+
indexes = copy.deepcopy(suffix_stack)
|
|
302
|
+
wrap_hook_fn = partial(hook_fn, indexes=indexes)
|
|
303
|
+
|
|
304
|
+
def real_hook_fn(grad):
|
|
305
|
+
return wrap_hook_fn(grad)
|
|
306
|
+
element.register_hook(real_hook_fn)
|
|
307
|
+
|
|
250
308
|
def if_return_forward_new_output(self):
|
|
251
309
|
return self._return_forward_new_output
|
|
252
310
|
|
|
@@ -383,3 +441,29 @@ class BaseDataProcessor:
|
|
|
383
441
|
suffix + file_format)
|
|
384
442
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
385
443
|
return dump_data_name, file_path
|
|
444
|
+
|
|
445
|
+
def analyze_element_to_all_none(self, element):
|
|
446
|
+
return self.recursive_apply_transform(element, lambda element, stack: None)
|
|
447
|
+
|
|
448
|
+
def analyze_debug_forward(self, variable, name_with_count):
|
|
449
|
+
self.current_api_or_module_name = name_with_count
|
|
450
|
+
self.api_data_category = Const.TENSOR
|
|
451
|
+
# these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
|
|
452
|
+
data_info = self.analyze_element(variable)
|
|
453
|
+
return data_info
|
|
454
|
+
|
|
455
|
+
def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
|
|
456
|
+
def hook_fn(grad, indexes):
|
|
457
|
+
suffix = Const.SEP.join([str(index) for index in indexes])
|
|
458
|
+
self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
|
|
459
|
+
grad_data_info = self.analyze_element(grad)
|
|
460
|
+
self.save_name = None
|
|
461
|
+
full_index = [grad_name_with_count] + indexes
|
|
462
|
+
try:
|
|
463
|
+
self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
|
|
464
|
+
except (ValueError, IndexError) as e:
|
|
465
|
+
logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
|
|
466
|
+
f"skip current recording, detailed infomation: {e}")
|
|
467
|
+
return grad
|
|
468
|
+
wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
|
|
469
|
+
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.core.common.const import Const
|
|
17
|
+
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class DataProcessorFactory:
|
|
@@ -62,6 +63,7 @@ class DataProcessorFactory:
|
|
|
62
63
|
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
|
|
63
64
|
cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
|
|
64
65
|
cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
|
|
66
|
+
cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
|
|
65
67
|
cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
|
|
66
68
|
elif framework == Const.MS_FRAMEWORK:
|
|
67
69
|
from msprobe.core.data_dump.data_processor.mindspore_processor import (
|
|
@@ -75,4 +77,5 @@ class DataProcessorFactory:
|
|
|
75
77
|
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
|
|
76
78
|
cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
|
|
77
79
|
cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
|
|
80
|
+
cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
|
|
78
81
|
cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
|
|
@@ -23,7 +23,7 @@ import numpy as np
|
|
|
23
23
|
from msprobe.core.common.const import Const
|
|
24
24
|
from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
|
|
25
25
|
ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
|
|
26
|
-
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
26
|
+
from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
|
|
27
27
|
from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
|
|
28
28
|
from msprobe.mindspore.common.log import logger
|
|
29
29
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
@@ -116,6 +116,10 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
116
116
|
api_register.norm_inner_op_set_hook_func()
|
|
117
117
|
return tensor_stat
|
|
118
118
|
|
|
119
|
+
@staticmethod
|
|
120
|
+
def is_hookable_element(element):
|
|
121
|
+
return hasattr(element, "register_hook") and callable(element.register_hook)
|
|
122
|
+
|
|
119
123
|
@classmethod
|
|
120
124
|
def get_special_types(cls):
|
|
121
125
|
return super().get_special_types() + cls.mindspore_special_type
|
|
@@ -136,11 +140,13 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
136
140
|
|
|
137
141
|
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
138
142
|
if converted_numpy is not element:
|
|
139
|
-
return
|
|
143
|
+
return {"type": numpy_type, "value": converted_numpy}
|
|
140
144
|
if isinstance(element, Number):
|
|
141
145
|
return self.analyze_dtype_in_kwargs(element)
|
|
142
146
|
if isinstance(element, ms.Tensor):
|
|
143
|
-
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
|
|
147
|
+
return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
148
|
+
if isinstance(element, np.ndarray):
|
|
149
|
+
return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
144
150
|
if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
|
|
145
151
|
return self._analyze_builtin(element)
|
|
146
152
|
return {}
|
|
@@ -185,6 +191,13 @@ class TensorDataProcessor(MindsporeDataProcessor):
|
|
|
185
191
|
else:
|
|
186
192
|
save_tensor_as_npy(tensor, file_path)
|
|
187
193
|
return single_arg
|
|
194
|
+
|
|
195
|
+
def _analyze_numpy(self, ndarray, suffix):
|
|
196
|
+
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
197
|
+
save_npy(ndarray, file_path)
|
|
198
|
+
ndarray_json = super()._analyze_numpy(ndarray, suffix)
|
|
199
|
+
ndarray_json.update({"data_name": dump_data_name})
|
|
200
|
+
return ndarray_json
|
|
188
201
|
|
|
189
202
|
|
|
190
203
|
class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
@@ -231,7 +244,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
231
244
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
232
245
|
self.maybe_save_overflow_data()
|
|
233
246
|
return api_info_struct if self.has_overflow else None
|
|
234
|
-
|
|
247
|
+
|
|
235
248
|
def analyze_params(self, name, param_name, grad):
|
|
236
249
|
self.has_overflow = False
|
|
237
250
|
api_info_struct = super().analyze_params(name, param_name, grad)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,6 +21,7 @@ from typing import List
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import torch
|
|
23
23
|
from torch import distributed as dist
|
|
24
|
+
from torch.distributed.distributed_c10d import _get_default_group
|
|
24
25
|
|
|
25
26
|
from msprobe.core.common.const import Const
|
|
26
27
|
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
@@ -40,7 +41,16 @@ except ImportError:
|
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
43
|
-
pytorch_special_type = (
|
|
44
|
+
pytorch_special_type = (
|
|
45
|
+
torch.device,
|
|
46
|
+
torch.dtype,
|
|
47
|
+
torch.Size,
|
|
48
|
+
torch.Tensor,
|
|
49
|
+
torch.memory_format,
|
|
50
|
+
dist.ProcessGroup,
|
|
51
|
+
dist.P2POp,
|
|
52
|
+
dist.ReduceOp
|
|
53
|
+
)
|
|
44
54
|
memory_format = {
|
|
45
55
|
torch.contiguous_format: "contiguous_format",
|
|
46
56
|
torch.channels_last: "channels_last",
|
|
@@ -168,6 +178,11 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
168
178
|
def is_distributed_op(module):
|
|
169
179
|
return getattr(module, "op_is_distributed", False)
|
|
170
180
|
|
|
181
|
+
@staticmethod
|
|
182
|
+
def is_hookable_element(element):
|
|
183
|
+
return (hasattr(element, "register_hook") and callable(element.register_hook)) and \
|
|
184
|
+
(hasattr(element, "requires_grad") and element.requires_grad)
|
|
185
|
+
|
|
171
186
|
@staticmethod
|
|
172
187
|
def _analyze_torch_size(arg):
|
|
173
188
|
return {"type": "torch.Size", "value": list(arg)}
|
|
@@ -176,7 +191,6 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
176
191
|
def _analyze_memory_format(arg):
|
|
177
192
|
# 获取内存格式
|
|
178
193
|
format_type = PytorchDataProcessor.memory_format.get(arg)
|
|
179
|
-
|
|
180
194
|
return {"type": "torch.memory_format", "format": format_type}
|
|
181
195
|
|
|
182
196
|
@staticmethod
|
|
@@ -188,9 +202,18 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
188
202
|
group_id = PytorchDataProcessor.process_group_hash(arg)
|
|
189
203
|
group_info.update({"group_id": group_id})
|
|
190
204
|
except Exception as e:
|
|
191
|
-
logger.warning(f"Failed to get process group
|
|
205
|
+
logger.warning(f"Failed to get process group ranks info with error info: {e}.")
|
|
192
206
|
return group_info
|
|
193
207
|
|
|
208
|
+
@staticmethod
|
|
209
|
+
def _analyze_reduce_op(arg):
|
|
210
|
+
op_type = None
|
|
211
|
+
try:
|
|
212
|
+
op_type = str(arg)
|
|
213
|
+
except Exception as e:
|
|
214
|
+
logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
|
|
215
|
+
return {"type": "torch.distributed.ReduceOp", "value": op_type}
|
|
216
|
+
|
|
194
217
|
@classmethod
|
|
195
218
|
def get_special_types(cls):
|
|
196
219
|
return super().get_special_types() + cls.pytorch_special_type
|
|
@@ -204,11 +227,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
204
227
|
return self._analyze_memory_format(element)
|
|
205
228
|
if isinstance(element, dist.ProcessGroup):
|
|
206
229
|
return self._analyze_process_group(element)
|
|
230
|
+
if isinstance(element, dist.P2POp):
|
|
231
|
+
return self._analyze_p2pop(element)
|
|
232
|
+
if isinstance(element, dist.ReduceOp):
|
|
233
|
+
return self._analyze_reduce_op(element)
|
|
207
234
|
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
208
235
|
if converted_numpy is not element:
|
|
209
|
-
return
|
|
236
|
+
return {"type": numpy_type, "value": converted_numpy}
|
|
210
237
|
if isinstance(element, torch.Tensor):
|
|
211
|
-
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
|
|
238
|
+
return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
239
|
+
if isinstance(element, np.ndarray):
|
|
240
|
+
return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
212
241
|
if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
|
|
213
242
|
return self._analyze_builtin(element)
|
|
214
243
|
return {}
|
|
@@ -218,6 +247,21 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
218
247
|
module_input_output.update_output_with_args_and_kwargs()
|
|
219
248
|
return super().analyze_forward_output(name, module, module_input_output)
|
|
220
249
|
|
|
250
|
+
def _analyze_p2pop(self, arg):
|
|
251
|
+
p2pop_info = {"class_type": "torch.distributed.P2POp"}
|
|
252
|
+
try:
|
|
253
|
+
tensor_info = self._analyze_tensor(arg.tensor, [])
|
|
254
|
+
p2pop_info.update({"tensor": tensor_info})
|
|
255
|
+
p2pop_info.update({"op": arg.op.__name__})
|
|
256
|
+
p2pop_info.update({"peer": arg.peer})
|
|
257
|
+
p2pop_info.update({"tag": arg.tag})
|
|
258
|
+
group_id = PytorchDataProcessor.process_group_hash(
|
|
259
|
+
arg.group) if arg.group else PytorchDataProcessor.process_group_hash(_get_default_group())
|
|
260
|
+
p2pop_info.update({"group_id": group_id})
|
|
261
|
+
except Exception as e:
|
|
262
|
+
logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
|
|
263
|
+
return p2pop_info
|
|
264
|
+
|
|
221
265
|
def _analyze_tensor(self, tensor, suffix):
|
|
222
266
|
tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
|
|
223
267
|
tensor_json = {}
|
|
@@ -267,6 +311,13 @@ class TensorDataProcessor(PytorchDataProcessor):
|
|
|
267
311
|
saved_tensor = tensor.clone().contiguous().detach()
|
|
268
312
|
save_pt(saved_tensor, file_path)
|
|
269
313
|
return single_arg
|
|
314
|
+
|
|
315
|
+
def _analyze_numpy(self, ndarray, suffix):
|
|
316
|
+
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
317
|
+
save_pt(torch.tensor(ndarray), file_path)
|
|
318
|
+
ndarray_json = super()._analyze_numpy(ndarray, suffix)
|
|
319
|
+
ndarray_json.update({"data_name": dump_data_name})
|
|
320
|
+
return ndarray_json
|
|
270
321
|
|
|
271
322
|
|
|
272
323
|
class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
@@ -319,7 +370,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
319
370
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
320
371
|
self.handle_overflow()
|
|
321
372
|
return api_info_struct if self.has_overflow else None
|
|
322
|
-
|
|
373
|
+
|
|
323
374
|
def analyze_params(self, name, param_name, grad):
|
|
324
375
|
self.has_overflow = False
|
|
325
376
|
self._is_support_inf_nan()
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import csv
|
|
17
17
|
import os
|
|
18
|
+
import copy
|
|
18
19
|
import numpy as np
|
|
19
20
|
|
|
20
21
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
@@ -31,10 +32,12 @@ class DataWriter:
|
|
|
31
32
|
self.construct_file_path = None
|
|
32
33
|
self.free_benchmark_file_path = None
|
|
33
34
|
self.dump_tensor_data_dir = None
|
|
35
|
+
self.debug_file_path = None
|
|
34
36
|
self.flush_size = 1000
|
|
35
37
|
self.cache_data = {}
|
|
36
38
|
self.cache_stack = {}
|
|
37
39
|
self.cache_construct = {}
|
|
40
|
+
self.cache_debug = {}
|
|
38
41
|
|
|
39
42
|
@staticmethod
|
|
40
43
|
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
@@ -57,6 +60,13 @@ class DataWriter:
|
|
|
57
60
|
self.cache_construct = {}
|
|
58
61
|
|
|
59
62
|
def initialize_json_file(self, **kwargs):
|
|
63
|
+
if self.debug_file_path and not self.cache_debug:
|
|
64
|
+
# debug level case only create debug.json
|
|
65
|
+
debug_dict = copy.deepcopy(kwargs)
|
|
66
|
+
debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
67
|
+
self.cache_debug = debug_dict
|
|
68
|
+
save_json(self.debug_file_path, self.cache_debug, indent=1)
|
|
69
|
+
return
|
|
60
70
|
if not self.cache_data:
|
|
61
71
|
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
62
72
|
self.cache_data = kwargs
|
|
@@ -66,13 +76,13 @@ class DataWriter:
|
|
|
66
76
|
if not self.cache_construct:
|
|
67
77
|
save_json(self.construct_file_path, self.cache_construct, indent=1)
|
|
68
78
|
|
|
69
|
-
def update_dump_paths(self,
|
|
70
|
-
|
|
71
|
-
self.
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
self.
|
|
79
|
+
def update_dump_paths(self, dump_path_aggregation):
|
|
80
|
+
self.dump_file_path = dump_path_aggregation.dump_file_path
|
|
81
|
+
self.stack_file_path = dump_path_aggregation.stack_file_path
|
|
82
|
+
self.construct_file_path = dump_path_aggregation.construct_file_path
|
|
83
|
+
self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
|
|
84
|
+
self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
|
|
85
|
+
self.debug_file_path = dump_path_aggregation.debug_file_path
|
|
76
86
|
|
|
77
87
|
def flush_data_periodically(self):
|
|
78
88
|
dump_data = self.cache_data.get(Const.DATA)
|
|
@@ -100,6 +110,9 @@ class DataWriter:
|
|
|
100
110
|
def update_construct(self, new_data):
|
|
101
111
|
self.cache_construct.update(new_data)
|
|
102
112
|
|
|
113
|
+
def update_debug(self, new_data):
|
|
114
|
+
self.cache_debug['data'].update(new_data)
|
|
115
|
+
|
|
103
116
|
def write_data_json(self, file_path):
|
|
104
117
|
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
|
|
105
118
|
save_json(file_path, self.cache_data, indent=1)
|
|
@@ -110,6 +123,9 @@ class DataWriter:
|
|
|
110
123
|
def write_construct_info_json(self, file_path):
|
|
111
124
|
save_json(file_path, self.cache_construct, indent=1)
|
|
112
125
|
|
|
126
|
+
def write_debug_info_json(self, file_path):
|
|
127
|
+
save_json(file_path, self.cache_debug, indent=1)
|
|
128
|
+
|
|
113
129
|
def write_json(self):
|
|
114
130
|
if self.cache_data:
|
|
115
131
|
self.write_data_json(self.dump_file_path)
|
|
@@ -117,6 +133,8 @@ class DataWriter:
|
|
|
117
133
|
self.write_stack_info_json(self.stack_file_path)
|
|
118
134
|
if self.cache_construct:
|
|
119
135
|
self.write_construct_info_json(self.construct_file_path)
|
|
136
|
+
if self.cache_debug:
|
|
137
|
+
self.write_debug_info_json(self.debug_file_path)
|
|
120
138
|
|
|
121
139
|
def fill_stack_tensor_data(self):
|
|
122
140
|
self.process_stat_data_recursive(self.cache_data)
|
|
@@ -135,7 +153,7 @@ class DataWriter:
|
|
|
135
153
|
if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
|
|
136
154
|
tensor_stat_data = tensor_stat_data.cpu()
|
|
137
155
|
for index, stat in zip(tensor_stat_index, tensor_stat_data):
|
|
138
|
-
data.update({index
|
|
156
|
+
data.update({index: stat.item()})
|
|
139
157
|
del data["tensor_stat"]
|
|
140
158
|
else:
|
|
141
159
|
for key in data.keys():
|