mindstudio-probe 8.2.0__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +63 -61
- msprobe/README.md +4 -4
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +16 -4
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +3 -2
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +21 -0
- msprobe/core/service.py +2 -0
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +15 -8
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/visualization/builder/graph_builder.py +169 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
|
@@ -26,6 +26,8 @@ from msprobe.core.compare.utils import gen_api_batches
|
|
|
26
26
|
|
|
27
27
|
cur_dir = os.path.dirname(os.path.realpath(__file__))
|
|
28
28
|
diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
|
|
29
|
+
ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')
|
|
30
|
+
ignore_list = load_yaml(ignore_op_list_yaml_path)
|
|
29
31
|
thresholds = load_yaml(diff_threshold_yaml_path)
|
|
30
32
|
cmp_metrics = thresholds.get('compare_metrics')
|
|
31
33
|
|
|
@@ -51,7 +53,7 @@ class FirstDiffAnalyze:
|
|
|
51
53
|
return True
|
|
52
54
|
return False
|
|
53
55
|
|
|
54
|
-
def single_api_check(self, result_slice, header):
|
|
56
|
+
def single_api_check(self, result_slice, header, api_name=None):
|
|
55
57
|
"""
|
|
56
58
|
单个api差异检查
|
|
57
59
|
|
|
@@ -65,14 +67,18 @@ class FirstDiffAnalyze:
|
|
|
65
67
|
}
|
|
66
68
|
|
|
67
69
|
column_indices = {name: idx for idx, name in enumerate(header)}
|
|
68
|
-
|
|
70
|
+
output_idx = -1
|
|
69
71
|
for line in result_slice:
|
|
70
72
|
op_item = {
|
|
71
73
|
column_name: line[column_indices[column_name]]
|
|
72
74
|
for column_name in header
|
|
73
75
|
}
|
|
74
76
|
single_check_result['op_items'].append(op_item)
|
|
75
|
-
|
|
77
|
+
if op_item['state'] != 'output':
|
|
78
|
+
continue
|
|
79
|
+
output_idx += 1
|
|
80
|
+
if output_idx in ignore_list.get(api_name, []):
|
|
81
|
+
continue
|
|
76
82
|
# set is_same
|
|
77
83
|
if self.mode_config.dump_mode == Const.MD5:
|
|
78
84
|
if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF:
|
|
@@ -117,7 +123,13 @@ class FirstDiffAnalyze:
|
|
|
117
123
|
with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
|
|
118
124
|
for api_batch in api_batches:
|
|
119
125
|
result_slice = result[api_batch.start: api_batch.params_grad_end_index]
|
|
120
|
-
|
|
126
|
+
api_compo = api_batch.api_name.split('.')
|
|
127
|
+
# suppose name is Tensor.MatMul.0.forward
|
|
128
|
+
if len(api_compo) < 4:
|
|
129
|
+
continue
|
|
130
|
+
# get MatMul as api_name
|
|
131
|
+
api_name = api_compo[-3]
|
|
132
|
+
check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
|
|
121
133
|
progress_bar.update(1)
|
|
122
134
|
|
|
123
135
|
return check_result
|
|
@@ -47,7 +47,6 @@ class DiffAnalyzer:
|
|
|
47
47
|
analyze_func()
|
|
48
48
|
if self._diff_nodes:
|
|
49
49
|
self._gen_analyze_info()
|
|
50
|
-
self._post_process()
|
|
51
50
|
return
|
|
52
51
|
logger.info('Cannot find any diff node, no need to generate analyze file.')
|
|
53
52
|
|
|
@@ -56,12 +55,6 @@ class DiffAnalyzer:
|
|
|
56
55
|
self._resolve_input_path(self._output_path)
|
|
57
56
|
logger.info("Pre Process completed.")
|
|
58
57
|
|
|
59
|
-
def _post_process(self):
|
|
60
|
-
for rank_path in self._paths.values():
|
|
61
|
-
dump_path = rank_path.dump_path
|
|
62
|
-
logger.debug(f"Remove {dump_path} success")
|
|
63
|
-
logger.info("Post Process completed.")
|
|
64
|
-
|
|
65
58
|
"""
|
|
66
59
|
这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中
|
|
67
60
|
"""
|
|
@@ -105,6 +98,8 @@ class DiffAnalyzer:
|
|
|
105
98
|
logger.warning(f'Rank {path.rank} has no dump data!')
|
|
106
99
|
continue
|
|
107
100
|
for op_name, op_data in dump_data.items():
|
|
101
|
+
if is_ignore_op(op_name):
|
|
102
|
+
continue
|
|
108
103
|
if is_communication_op(op_name):
|
|
109
104
|
self._first_comm_nodes[path.rank] = op_name
|
|
110
105
|
break
|
|
@@ -131,10 +126,16 @@ class DiffAnalyzer:
|
|
|
131
126
|
for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
|
|
132
127
|
searched_ranks.add(rank)
|
|
133
128
|
seen_nodes = set()
|
|
129
|
+
last_node = None
|
|
134
130
|
for cur_node in nodes.values():
|
|
131
|
+
is_overflow = last_node and hasattr(last_node, 'layer') and hasattr(cur_node, 'layer') and \
|
|
132
|
+
last_node.layer >= cur_node.layer
|
|
133
|
+
if is_overflow:
|
|
134
|
+
cur_node.layer = last_node.layer + 1
|
|
135
135
|
conn_info = cur_node.find_connected_nodes()
|
|
136
136
|
if not conn_info.get('ranks'):
|
|
137
137
|
conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
|
|
138
|
+
last_node = cur_node
|
|
138
139
|
if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
|
|
139
140
|
logger.debug(f'Cannot find connected communication node for "{cur_node.node_id}".')
|
|
140
141
|
|
|
@@ -52,19 +52,25 @@ class DataNode:
|
|
|
52
52
|
metrics = {}
|
|
53
53
|
for cmp_data in self.op_data:
|
|
54
54
|
name = cmp_data.get(CompareConst.NPU_NAME)
|
|
55
|
+
# 构建度量指标字典
|
|
56
|
+
metrics = {}
|
|
57
|
+
|
|
55
58
|
if CompareConst.NPU_MAX in cmp_data:
|
|
56
59
|
metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX),
|
|
57
60
|
CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN),
|
|
58
61
|
CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN),
|
|
59
62
|
CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)}
|
|
60
63
|
elif CompareConst.NPU_MD5 in cmp_data:
|
|
61
|
-
metrics
|
|
64
|
+
metrics[CompareConst.NPU_MD5] = cmp_data.get(CompareConst.NPU_MD5)
|
|
65
|
+
|
|
66
|
+
if CompareConst.NPU_P2POP_PEER in cmp_data:
|
|
67
|
+
metrics[CompareConst.NPU_P2POP_PEER] = cmp_data.get(CompareConst.NPU_P2POP_PEER)
|
|
62
68
|
|
|
63
69
|
if cmp_data.get(CompareConst.STACK) != CompareConst.N_A and not self.stack:
|
|
64
70
|
self.stack = cmp_data.get(CompareConst.STACK)
|
|
65
|
-
if
|
|
71
|
+
if cmp_data.get('state') == "input":
|
|
66
72
|
self.inputs[name] = metrics
|
|
67
|
-
elif
|
|
73
|
+
elif cmp_data.get('state') == "output":
|
|
68
74
|
self.outputs[name] = metrics
|
|
69
75
|
|
|
70
76
|
def gen_node_info(self, path: RankPath):
|
|
@@ -161,6 +167,8 @@ class CommunicationNode:
|
|
|
161
167
|
if val and val.startswith('[') and val.endswith(']'):
|
|
162
168
|
val = [int(part) for part in val.strip('[]').split(',')]
|
|
163
169
|
ranks.update(val)
|
|
170
|
+
elif v.get(CompareConst.NPU_P2POP_PEER) != "None":
|
|
171
|
+
ranks.add(v.get(CompareConst.NPU_P2POP_PEER))
|
|
164
172
|
|
|
165
173
|
return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
|
|
166
174
|
'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)}
|
|
@@ -120,7 +120,8 @@ def is_communication_op(op_name):
|
|
|
120
120
|
def is_ignore_op(op_name):
|
|
121
121
|
ignore_keywords = [
|
|
122
122
|
'Torch.empty',
|
|
123
|
-
'Torch.fill'
|
|
123
|
+
'Torch.fill',
|
|
124
|
+
'Tensor.__setitem__'
|
|
124
125
|
]
|
|
125
126
|
return any(keyword in op_name for keyword in ignore_keywords)
|
|
126
127
|
|
|
@@ -181,7 +182,7 @@ def analyze_diff_in_group(nodes_group):
|
|
|
181
182
|
input_diff_nodes = list(filter(lambda node: node.is_diff, src_list))
|
|
182
183
|
# 如果有异常回溯计算节点找到异常来源
|
|
183
184
|
# 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。
|
|
184
|
-
get_compute_ops_from_comm_nodes(
|
|
185
|
+
get_compute_ops_from_comm_nodes(nodes_group)
|
|
185
186
|
# 筛选入参没问题但出参有问题的通信节点
|
|
186
187
|
output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group))
|
|
187
188
|
get_comm_ops(output_diff_nodes)
|
|
@@ -26,7 +26,7 @@ from tqdm import tqdm
|
|
|
26
26
|
from msprobe.core.common.const import CompareConst, Const
|
|
27
27
|
from msprobe.core.common.file_utils import save_workbook
|
|
28
28
|
from msprobe.core.common.log import logger
|
|
29
|
-
from msprobe.core.common.utils import get_header_index
|
|
29
|
+
from msprobe.core.common.utils import get_header_index, CompareException
|
|
30
30
|
from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches
|
|
31
31
|
from msprobe.core.compare.config import ModeConfig
|
|
32
32
|
|
|
@@ -359,18 +359,25 @@ class HighLight:
|
|
|
359
359
|
|
|
360
360
|
def err_call(args):
|
|
361
361
|
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
362
|
-
try:
|
|
363
|
-
pool.close()
|
|
364
|
-
except OSError:
|
|
365
|
-
logger.error("Pool terminate failed")
|
|
366
362
|
|
|
367
363
|
result_df_columns = result_df.columns.tolist()
|
|
368
364
|
for column in result_df_columns:
|
|
369
365
|
self.value_check(column)
|
|
366
|
+
async_results = []
|
|
370
367
|
for df_chunk in chunks:
|
|
371
|
-
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
368
|
+
result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
369
|
+
async_results.append(result)
|
|
372
370
|
|
|
373
371
|
pool.close()
|
|
372
|
+
|
|
373
|
+
for ar in async_results:
|
|
374
|
+
try:
|
|
375
|
+
ar.get(timeout=3600)
|
|
376
|
+
except Exception as e:
|
|
377
|
+
logger.error(f"Task failed with exception: {e}")
|
|
378
|
+
pool.terminate()
|
|
379
|
+
raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
|
|
380
|
+
|
|
374
381
|
pool.join()
|
|
375
382
|
|
|
376
383
|
def df_malicious_value_check(self, result_df):
|
|
@@ -52,16 +52,20 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
52
52
|
|
|
53
53
|
def err_call(args):
|
|
54
54
|
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
55
|
-
try:
|
|
56
|
-
pool.close()
|
|
57
|
-
except OSError as e:
|
|
58
|
-
logger.error(f'pool terminate failed: {str(e)}')
|
|
59
55
|
|
|
60
56
|
for df_chunk in df_chunks:
|
|
61
57
|
result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
|
|
62
58
|
results.append(result)
|
|
63
|
-
|
|
59
|
+
|
|
64
60
|
pool.close()
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
final_results = [r.get(timeout=3600) for r in results]
|
|
64
|
+
except Exception as e:
|
|
65
|
+
logger.error(f"Task failed with exception: {e}")
|
|
66
|
+
pool.terminate()
|
|
67
|
+
raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
|
|
68
|
+
|
|
65
69
|
pool.join()
|
|
66
70
|
return pd.concat(final_results, ignore_index=True)
|
|
67
71
|
|
|
@@ -277,10 +281,6 @@ class CompareRealData:
|
|
|
277
281
|
|
|
278
282
|
def err_call(args):
|
|
279
283
|
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
280
|
-
try:
|
|
281
|
-
pool.close()
|
|
282
|
-
except OSError:
|
|
283
|
-
logger.error("pool terminate failed")
|
|
284
284
|
|
|
285
285
|
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
286
286
|
|
|
@@ -298,7 +298,14 @@ class CompareRealData:
|
|
|
298
298
|
)
|
|
299
299
|
results.append(result)
|
|
300
300
|
|
|
301
|
-
final_results = [r.get() for r in results]
|
|
302
301
|
pool.close()
|
|
302
|
+
|
|
303
|
+
try:
|
|
304
|
+
final_results = [r.get(timeout=3600) for r in results]
|
|
305
|
+
except Exception as e:
|
|
306
|
+
logger.error(f"Task failed with exception: {e}")
|
|
307
|
+
pool.terminate()
|
|
308
|
+
raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
|
|
309
|
+
|
|
303
310
|
pool.join()
|
|
304
311
|
return pd.concat(final_results, ignore_index=True)
|
msprobe/core/compare/utils.py
CHANGED
|
@@ -695,10 +695,6 @@ def get_sorted_ranks(npu_dump_dir, bench_dump_dir):
|
|
|
695
695
|
def multi_statistics_compare(func, func_args):
|
|
696
696
|
def err_call(args):
|
|
697
697
|
logger.error(f'Multiprocess statistics compare failed! Reason: {args}')
|
|
698
|
-
try:
|
|
699
|
-
pool.close()
|
|
700
|
-
except OSError:
|
|
701
|
-
logger.error("Pool terminate failed")
|
|
702
698
|
|
|
703
699
|
compare_func, input_param_nr_list, output_path, kwargs = func_args
|
|
704
700
|
|
|
@@ -715,9 +711,22 @@ def multi_statistics_compare(func, func_args):
|
|
|
715
711
|
chunks[i].append(input_param_nr_list[param_num - remainder + i])
|
|
716
712
|
|
|
717
713
|
pool = multiprocessing.Pool(process_num)
|
|
714
|
+
|
|
715
|
+
async_results = []
|
|
718
716
|
for chunk in chunks:
|
|
719
|
-
pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call)
|
|
717
|
+
result = pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call)
|
|
718
|
+
async_results.append(result)
|
|
719
|
+
|
|
720
720
|
pool.close()
|
|
721
|
+
|
|
722
|
+
for ar in async_results:
|
|
723
|
+
try:
|
|
724
|
+
ar.get(timeout=3600)
|
|
725
|
+
except Exception as e:
|
|
726
|
+
logger.error(f"Task failed with exception: {e}")
|
|
727
|
+
pool.terminate()
|
|
728
|
+
raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
|
|
729
|
+
|
|
721
730
|
pool.join()
|
|
722
731
|
|
|
723
732
|
|
|
@@ -23,6 +23,7 @@ from msprobe.core.data_dump.json_writer import DataWriter
|
|
|
23
23
|
from msprobe.core.common.log import logger
|
|
24
24
|
from msprobe.core.common.const import Const
|
|
25
25
|
from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
|
|
26
|
+
from msprobe.core.common.megatron_utils import MegatronStepInfo, get_micro_step, is_megatron
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
def build_data_collector(config):
|
|
@@ -270,15 +271,20 @@ class DataCollector:
|
|
|
270
271
|
if self.config.level not in DataCollector.level_without_construct:
|
|
271
272
|
if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
|
|
272
273
|
if self.optimizer_status_first_start[self.optimizer_status]:
|
|
273
|
-
self.data_writer.update_construct(
|
|
274
|
+
self.data_writer.update_construct(
|
|
275
|
+
{self.optimizer_status: None if not is_megatron() else [None, get_micro_step()]})
|
|
274
276
|
self.optimizer_status_first_start[self.optimizer_status] = False
|
|
275
|
-
self.data_writer.update_construct(
|
|
277
|
+
self.data_writer.update_construct(
|
|
278
|
+
{name: self.optimizer_status if not is_megatron() else [self.optimizer_status, get_micro_step()]})
|
|
276
279
|
else:
|
|
277
280
|
if self.config.level == Const.LEVEL_MIX and \
|
|
278
281
|
not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
|
|
279
282
|
self.data_writer.update_construct(
|
|
280
283
|
{name: self.module_processor.api_parent_node.get(threading.get_ident())}
|
|
281
284
|
)
|
|
285
|
+
if MegatronStepInfo.is_megatron:
|
|
286
|
+
micro_step_number = max(MegatronStepInfo.forward_micro_step, MegatronStepInfo.backward_micro_step)
|
|
287
|
+
self.data_writer.update_construct({Const.MEGATRON_MICRO_STEP_NUMBER: micro_step_number})
|
|
282
288
|
|
|
283
289
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
284
290
|
|
|
@@ -302,25 +308,16 @@ class DataCollector:
|
|
|
302
308
|
self.data_processor.update_iter(current_iter)
|
|
303
309
|
|
|
304
310
|
def params_data_collect(self, name, param_name, pid, data):
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
if
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
self.params_grad_record[grad_name] = False
|
|
316
|
-
except Exception as e:
|
|
317
|
-
error_type = type(e).__name__
|
|
318
|
-
tb = traceback.format_exc()
|
|
319
|
-
self.data_writer.write_error_log(
|
|
320
|
-
f"[ERROR] params_data_collect failed: "
|
|
321
|
-
f"name={name}, param_name={param_name}, pid={pid}\n{tb}",
|
|
322
|
-
error_type=error_type
|
|
323
|
-
)
|
|
311
|
+
grad_name = name + Const.SEP + Const.PARAMS_GRAD
|
|
312
|
+
self.update_api_or_module_name(grad_name)
|
|
313
|
+
if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
|
|
314
|
+
if self.data_writer.cache_data.get("data"):
|
|
315
|
+
self.data_writer.cache_data.get("data").pop(grad_name, None)
|
|
316
|
+
self.params_grad_record[grad_name] = False
|
|
317
|
+
return
|
|
318
|
+
data_info = self.data_processor.analyze_params(grad_name, param_name, data)
|
|
319
|
+
self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
|
|
320
|
+
self.params_grad_record[grad_name] = False
|
|
324
321
|
|
|
325
322
|
def params_data_collect_in_bw_hook(self, params_dict, name):
|
|
326
323
|
try:
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import ctypes
|
|
16
17
|
import os
|
|
17
18
|
import zlib
|
|
18
|
-
import ctypes
|
|
19
19
|
from collections.abc import Iterable
|
|
20
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
20
21
|
from dataclasses import asdict
|
|
21
22
|
from typing import List
|
|
22
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
import torch
|
|
@@ -29,7 +29,6 @@ from torch.distributed.distributed_c10d import _get_default_group
|
|
|
29
29
|
from msprobe.core.common.const import Const
|
|
30
30
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
31
31
|
from msprobe.core.common.exceptions import MsprobeException
|
|
32
|
-
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
33
32
|
from msprobe.core.common.log import logger
|
|
34
33
|
from msprobe.core.common.utils import convert_tuple, is_int
|
|
35
34
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
@@ -48,15 +47,28 @@ class TensorHandler:
|
|
|
48
47
|
def __init__(self):
|
|
49
48
|
self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor")
|
|
50
49
|
self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor")
|
|
50
|
+
self.has_async_collective_tensor = hasattr(dist, "_functional_collectives") and \
|
|
51
|
+
hasattr(dist._functional_collectives, "AsyncCollectiveTensor")
|
|
52
|
+
|
|
53
|
+
@staticmethod
|
|
54
|
+
def free_tensor(tensor, tensor_name):
|
|
55
|
+
try:
|
|
56
|
+
tensor.untyped_storage().resize_(0)
|
|
57
|
+
except Exception as e:
|
|
58
|
+
logger.warning(f"Failed to free tensor: {tensor_name}, the detail info: {e}.")
|
|
51
59
|
|
|
52
60
|
def is_dtensor(self, tensor):
|
|
53
|
-
return self.has_dtensor and isinstance(tensor,
|
|
61
|
+
return self.has_dtensor and isinstance(tensor, dist.tensor.DTensor)
|
|
54
62
|
|
|
55
63
|
def is_fake_tensor(self, tensor):
|
|
56
64
|
return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor)
|
|
57
65
|
|
|
66
|
+
def is_async_collective_tensor(self, tensor):
|
|
67
|
+
return self.has_async_collective_tensor and \
|
|
68
|
+
isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor)
|
|
69
|
+
|
|
58
70
|
def is_empty_data(self, tensor):
|
|
59
|
-
return tensor.is_meta or self.is_fake_tensor(tensor)
|
|
71
|
+
return tensor.is_meta or self.is_fake_tensor(tensor) or self.is_async_collective_tensor(tensor)
|
|
60
72
|
|
|
61
73
|
def convert_common_tensor(self, tensor):
|
|
62
74
|
if self.is_dtensor(tensor):
|
|
@@ -71,6 +83,8 @@ class TensorHandler:
|
|
|
71
83
|
return Const.DTENSOR_TYPE
|
|
72
84
|
if self.is_fake_tensor(tensor):
|
|
73
85
|
return Const.FAKE_TENSOR_TYPE
|
|
86
|
+
if self.is_async_collective_tensor(tensor):
|
|
87
|
+
return Const.AC_TENSOR_TYPE
|
|
74
88
|
return Const.TENSOR_TYPE
|
|
75
89
|
|
|
76
90
|
def get_dtensor_info(self, tensor):
|
|
@@ -94,6 +108,18 @@ class TensorHandler:
|
|
|
94
108
|
dtensor_info.update({"placements": placements})
|
|
95
109
|
return dtensor_info
|
|
96
110
|
|
|
111
|
+
def save_tensor(self, tensor, file_path):
|
|
112
|
+
common_tensor = self.convert_common_tensor(tensor)
|
|
113
|
+
if self.is_empty_data(common_tensor):
|
|
114
|
+
logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
|
|
115
|
+
return
|
|
116
|
+
if common_tensor.untyped_storage().data_ptr() == 0:
|
|
117
|
+
logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
|
|
118
|
+
return
|
|
119
|
+
saved_tensor = common_tensor.clone().contiguous().detach()
|
|
120
|
+
save_pt(saved_tensor, file_path)
|
|
121
|
+
self.free_tensor(saved_tensor, file_path)
|
|
122
|
+
|
|
97
123
|
|
|
98
124
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
99
125
|
pytorch_special_type = (
|
|
@@ -288,7 +314,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
288
314
|
|
|
289
315
|
def dump_async_data(self):
|
|
290
316
|
for file_path, tensor in self._async_dump_cache.items():
|
|
291
|
-
|
|
317
|
+
self.tensor_handler.save_tensor(tensor, file_path)
|
|
292
318
|
self._async_dump_cache.clear()
|
|
293
319
|
|
|
294
320
|
def analyze_single_element(self, element, suffix_stack):
|
|
@@ -385,24 +411,24 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
385
411
|
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
386
412
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
387
413
|
single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
414
|
+
common_tensor = self.tensor_handler.convert_common_tensor(tensor)
|
|
415
|
+
if self.tensor_handler.is_empty_data(common_tensor):
|
|
416
|
+
logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
|
|
417
|
+
return single_arg
|
|
418
|
+
if common_tensor.untyped_storage().data_ptr() == 0:
|
|
419
|
+
logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
|
|
393
420
|
return single_arg
|
|
394
421
|
|
|
395
422
|
single_arg.update({"data_name": dump_data_name})
|
|
396
423
|
if self.config.async_dump:
|
|
397
|
-
self._async_dump_cache[file_path] =
|
|
424
|
+
self._async_dump_cache[file_path] = common_tensor.clone().detach()
|
|
398
425
|
else:
|
|
399
|
-
|
|
400
|
-
save_pt(saved_tensor, file_path)
|
|
426
|
+
self.tensor_handler.save_tensor(common_tensor, file_path)
|
|
401
427
|
return single_arg
|
|
402
428
|
|
|
403
429
|
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
404
430
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
405
|
-
|
|
431
|
+
self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path)
|
|
406
432
|
ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
407
433
|
ndarray_json.update({"data_name": dump_data_name})
|
|
408
434
|
return ndarray_json
|
|
@@ -493,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
493
519
|
self._analyze_maybe_overflow_flag()
|
|
494
520
|
if self.has_overflow:
|
|
495
521
|
for file_path, tensor in self.cached_tensors_and_file_paths.items():
|
|
496
|
-
|
|
522
|
+
self.tensor_handler.save_tensor(tensor, file_path)
|
|
497
523
|
self.real_overflow_nums += 1
|
|
498
524
|
if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
|
|
499
525
|
logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
|
|
@@ -538,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
538
564
|
|
|
539
565
|
def _analyze_tensor(self, tensor, suffix):
|
|
540
566
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
541
|
-
|
|
542
|
-
self.cached_tensors_and_file_paths.update({file_path: tensor})
|
|
543
|
-
else:
|
|
544
|
-
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
567
|
+
self.cached_tensors_and_file_paths.update({file_path: tensor})
|
|
545
568
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
546
569
|
single_arg.update({"data_name": dump_data_name})
|
|
547
570
|
if not self.has_overflow and self.support_inf_nan:
|
|
@@ -13,18 +13,18 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import concurrent
|
|
17
|
+
import copy
|
|
16
18
|
import csv
|
|
17
19
|
import os
|
|
18
|
-
import copy
|
|
19
20
|
import threading
|
|
20
21
|
import traceback
|
|
21
22
|
from datetime import datetime, timezone, timedelta
|
|
22
23
|
|
|
23
|
-
import concurrent
|
|
24
24
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
25
|
-
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
|
|
26
|
-
from msprobe.core.common.log import logger
|
|
27
25
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
26
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, check_path_before_create
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
28
|
|
|
29
29
|
lock = threading.Lock()
|
|
30
30
|
|
|
@@ -40,6 +40,7 @@ class DataWriter:
|
|
|
40
40
|
self.debug_file_path = None
|
|
41
41
|
self.dump_error_info_path = None
|
|
42
42
|
self.flush_size = 1000
|
|
43
|
+
self.md5_flush_size = 5000
|
|
43
44
|
self.larger_flush_size = 20000
|
|
44
45
|
self.cache_data = {}
|
|
45
46
|
self.cache_stack = {}
|
|
@@ -49,6 +50,7 @@ class DataWriter:
|
|
|
49
50
|
self._error_log_initialized = False
|
|
50
51
|
self._cache_logged_error_types = set()
|
|
51
52
|
self.crc32_stack_list = []
|
|
53
|
+
self.data_updated = False
|
|
52
54
|
|
|
53
55
|
@staticmethod
|
|
54
56
|
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
@@ -60,7 +62,7 @@ class DataWriter:
|
|
|
60
62
|
spawn_writer = csv.writer(csv_file)
|
|
61
63
|
if not is_exists:
|
|
62
64
|
spawn_writer.writerow(result_header)
|
|
63
|
-
spawn_writer.writerows([result,])
|
|
65
|
+
spawn_writer.writerows([result, ])
|
|
64
66
|
is_new_file = not is_exists
|
|
65
67
|
if is_new_file:
|
|
66
68
|
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -190,7 +192,7 @@ class DataWriter:
|
|
|
190
192
|
summary_mode = getattr(cfg, "summary_mode", None)
|
|
191
193
|
|
|
192
194
|
if summary_mode == Const.MD5:
|
|
193
|
-
threshold = self.
|
|
195
|
+
threshold = self.md5_flush_size
|
|
194
196
|
else:
|
|
195
197
|
threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
|
|
196
198
|
|
|
@@ -238,6 +240,7 @@ class DataWriter:
|
|
|
238
240
|
logger.warning(f"The dump data({dump_data}) should be a dict.")
|
|
239
241
|
return
|
|
240
242
|
|
|
243
|
+
self.data_updated = True
|
|
241
244
|
key = next(iter(new_data.keys()))
|
|
242
245
|
if key in dump_data:
|
|
243
246
|
dump_data.get(key).update(new_data.get(key))
|
|
@@ -246,6 +249,7 @@ class DataWriter:
|
|
|
246
249
|
|
|
247
250
|
def update_stack(self, name, stack_data):
|
|
248
251
|
with lock:
|
|
252
|
+
self.data_updated = True
|
|
249
253
|
api_list = self.cache_stack.get(stack_data)
|
|
250
254
|
if api_list is None:
|
|
251
255
|
self.cache_stack.update({stack_data: [name]})
|
|
@@ -254,10 +258,12 @@ class DataWriter:
|
|
|
254
258
|
|
|
255
259
|
def update_construct(self, new_data):
|
|
256
260
|
with lock:
|
|
261
|
+
self.data_updated = True
|
|
257
262
|
self.cache_construct.update(new_data)
|
|
258
263
|
|
|
259
264
|
def update_debug(self, new_data):
|
|
260
265
|
with lock:
|
|
266
|
+
self.data_updated = True
|
|
261
267
|
self.cache_debug['data'].update(new_data)
|
|
262
268
|
|
|
263
269
|
def write_data_json(self, file_path):
|
|
@@ -324,17 +330,21 @@ class DataWriter:
|
|
|
324
330
|
stat_result = self.flush_stat_stack()
|
|
325
331
|
# 遍历 cache_data,将占位符替换为最终统计值
|
|
326
332
|
if stat_result:
|
|
333
|
+
self.data_updated = True
|
|
327
334
|
self._replace_stat_placeholders(self.cache_data, stat_result)
|
|
328
335
|
if self.cache_debug:
|
|
329
336
|
self._replace_stat_placeholders(self.cache_debug, stat_result)
|
|
330
337
|
|
|
331
|
-
# 2) 再 flush CRC32
|
|
332
338
|
crc32_result = self.flush_crc32_stack()
|
|
333
339
|
if crc32_result:
|
|
340
|
+
self.data_updated = True
|
|
334
341
|
self._replace_crc32_placeholders(self.cache_data, crc32_result)
|
|
335
342
|
if self.cache_debug:
|
|
336
343
|
self._replace_crc32_placeholders(self.cache_debug, crc32_result)
|
|
337
344
|
|
|
345
|
+
if not self.data_updated:
|
|
346
|
+
return
|
|
347
|
+
|
|
338
348
|
if self.cache_data:
|
|
339
349
|
self.write_data_json(self.dump_file_path)
|
|
340
350
|
if self.cache_stack:
|
|
@@ -343,4 +353,4 @@ class DataWriter:
|
|
|
343
353
|
self.write_construct_info_json(self.construct_file_path)
|
|
344
354
|
if self.cache_debug:
|
|
345
355
|
self.write_debug_info_json(self.debug_file_path)
|
|
346
|
-
|
|
356
|
+
self.data_updated = False
|
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -69,8 +69,7 @@ class BaseScope(ABC):
|
|
|
69
69
|
self.scope = scope
|
|
70
70
|
self.api_list = api_list
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
def rectify_args(scope, api_list):
|
|
72
|
+
def rectify_args(self, scope, api_list):
|
|
74
73
|
if not isinstance(api_list, list):
|
|
75
74
|
raise ScopeException(ScopeException.InvalidApiStr,
|
|
76
75
|
f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
|
|
@@ -104,12 +103,11 @@ class BaseScope(ABC):
|
|
|
104
103
|
|
|
105
104
|
|
|
106
105
|
class ListScope(BaseScope):
|
|
107
|
-
|
|
108
|
-
def rectify_args(scope, api_list):
|
|
106
|
+
def rectify_args(self, scope, api_list):
|
|
109
107
|
if scope and api_list:
|
|
110
108
|
raise ScopeException(ScopeException.ArgConflict,
|
|
111
109
|
f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
|
|
112
|
-
return super(
|
|
110
|
+
return super().rectify_args(scope, api_list)
|
|
113
111
|
|
|
114
112
|
def check(self, name):
|
|
115
113
|
if not self.scope or name in self.scope:
|
|
@@ -147,7 +145,7 @@ class RangeScope(BaseScope, ABC):
|
|
|
147
145
|
f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
|
|
148
146
|
|
|
149
147
|
def rectify_args(self, scope, api_list):
|
|
150
|
-
scope, api_list = super(
|
|
148
|
+
scope, api_list = super().rectify_args(scope, api_list)
|
|
151
149
|
if scope and len(scope) != 2:
|
|
152
150
|
raise ScopeException(ScopeException.InvalidScope,
|
|
153
151
|
f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
|