mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,7 @@ from dataclasses import dataclass

 import numpy as np

-from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.const import Const, CompareConst, FileCheckConst
 from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
 from msprobe.core.common.file_utils import check_file_or_directory_path

@@ -37,13 +37,20 @@ def extract_json(dirname, stack_json=False):
     # Provide robustness on invalid directory inputs
     if not json_path:
         if stack_json:
-            logger.
+            logger.warning(f'stack.json is not found in dump dir {dirname}.')
         else:
             logger.error(f'dump.json is not found in dump dir {dirname}.')
+            raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
     return json_path


+def set_stack_json_path(input_param):
+    npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
+    stack_path = extract_json(npu_data_dir, stack_json=True)
+    input_param["stack_json_path"] = stack_path if stack_path else None
+    return bool(stack_path)
+
+
 def check_and_return_dir_contents(dump_dir, prefix):
     """
     check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
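A minimal, standalone sketch of the behaviour the new `set_stack_json_path` helper adds: it resolves stack.json from the directory holding the NPU dump.json and records the result (or None) in the input dict. The temporary directory and file names below are made up for illustration; the real helper also goes through `extract_json` for validation.

```python
import os
import tempfile

# Hypothetical dump directory containing a stack.json next to the NPU dump.json.
dump_dir = tempfile.mkdtemp()
open(os.path.join(dump_dir, "stack.json"), "w").close()

input_param = {"npu_json_path": os.path.join(dump_dir, "dump.json")}

# Same net effect as set_stack_json_path(input_param): store the path if the file
# exists, otherwise None, and report whether a stack file was found.
stack_path = os.path.join(os.path.dirname(input_param["npu_json_path"]), "stack.json")
input_param["stack_json_path"] = stack_path if os.path.exists(stack_path) else None
print(bool(input_param["stack_json_path"]))  # True
```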
@@ -75,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):


 def rename_api(npu_name, process):
+    """
+    Original api: {api_type}.{api_name}.{call count}.{forward/backward}.{input/output}.{arg index}
+    After rename: {api_type}.{api_name}.{input/output}.{arg index}
+    """
     npu_split = npu_name.split(process)
     try:
         torch_func_index, in_out = npu_split[0], npu_split[1]
@@ -87,17 +98,13 @@ def rename_api(npu_name, process):


 def read_op(op_data, op_name):
-    op_parsed_list = []
-    for name in io_name_mapping:
-        if name in op_data:
-            op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
+    if Const.PARAMS_GRAD in op_name.split(Const.SEP):
+        op_parsed_list = op_item_parse(op_data, op_name)
+    else:
+        op_parsed_list = []
+        for name in CompareConst.IO_NAME_MAPPING:
+            if name in op_data:
+                op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name]))
     return op_parsed_list
@@ -124,11 +131,14 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
         return [default_item]
     elif not op_data:
         return []

     item_list = []
     if isinstance(op_data, list):
         for i, data in enumerate(op_data):
+            if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
+                item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
+            else:
+                item_list.extend(op_item_parse(data, op_name, depth + 1))
     elif isinstance(op_data, dict):
         if is_leaf_data(op_data):
             return [gen_op_item(op_data, op_name)]
@@ -144,14 +154,15 @@ def is_leaf_data(op_data):
 def gen_op_item(op_data, op_name):
     op_item = {}
     op_item.update(op_data)
-    op_item['data_name'] =
+    data_name = op_data.get('data_name') if op_data.get('data_name') else '-1'  # also return '-1' for ""
+    op_item['data_name'] = data_name
+    op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name

     params = ['Max', 'Min', 'Mean', 'Norm']
     for i in params:
         if i not in op_item:
             op_item[i] = None

     if not op_item.get('dtype'):
         if op_item.get('type') == 'torch.Size':
             op_item['dtype'] = op_data.get('type')
@@ -159,6 +170,16 @@ def gen_op_item(op_data, op_name):
         elif op_item.get('type') == 'slice':
             op_item['dtype'] = op_data.get('type')
             op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
+        elif op_item.get('type') == 'ellipsis':
+            op_item['dtype'] = op_data.get('type')
+            op_item['shape'] = '[]'
+            for i in params:
+                op_item[i] = op_data.get('value')
+        elif op_item.get('type') == 'torch.ProcessGroup':
+            op_item['dtype'] = op_data.get('type')
+            op_item['shape'] = '[]'
+            for i in params:
+                op_item[i] = str(op_data.get('group_ranks'))
         else:
             op_item['dtype'] = str(type(op_data.get('value')))
             op_item['shape'] = '[]'
@@ -166,7 +187,7 @@ def gen_op_item(op_data, op_name):
             op_item[i] = op_data.get('value')
         if not op_item.get('md5'):
             op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"

     return op_item
@@ -276,6 +297,22 @@ def result_item_init(n_info, b_info, dump_mode):
     return result_item


+def count_struct(op_dict):
+    parts = [
+        CompareConst.OP_NAME,
+        CompareConst.INPUT_STRUCT,
+        CompareConst.OUTPUT_STRUCT,
+        CompareConst.PARAMS_STRUCT,
+        CompareConst.PARAMS_GRAD_STRUCT
+    ]
+    lengths = [len(op_dict.get(part, [])) for part in parts]
+    num = lengths[0]
+    if num != sum(lengths[1:]):
+        logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
+        raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
+    return tuple(lengths)
+
+
 def get_accuracy(result, n_dict, b_dict, dump_mode):
     def get_accuracy_core(n_start, n_len, b_start, b_len, key):
         min_len = min(n_len, b_len)
@@ -355,31 +392,50 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):

         result.append(result_item)

-    n_num =
-    b_num =
-    n_num_output
+    n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
+    b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
+
+    get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
+    get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
+                      CompareConst.PARAMS_STRUCT)
+    get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
+    get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
+                      b_num_input + b_num_output + b_num_params, b_num_params_grad,
+                      CompareConst.PARAMS_GRAD_STRUCT)
+
+
+def append_stack_info(result_item, npu_stack_info, index):
+    """Append stack info to result_item."""
+    if npu_stack_info and index == 0:
+        result_item.extend(npu_stack_info)
+    else:
+        result_item.append(CompareConst.NONE)


 def get_un_match_accuracy(result, n_dict, dump_mode):
-    index_out = 0
     npu_stack_info = n_dict.get("stack_info", None)
     bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
+
+    struct_to_index_mapping = {
+        CompareConst.INPUT_STRUCT: 0,
+        CompareConst.OUTPUT_STRUCT: 0,
+        CompareConst.PARAMS_STRUCT: 0,
+        CompareConst.PARAMS_GRAD_STRUCT: 0
+    }
+
+    op_name_list = n_dict.get(CompareConst.OP_NAME)
+    summary_list = n_dict.get(Const.SUMMARY)
+    data_name_list = n_dict.get('data_name')
+    op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
+                                                            summary_list,
+                                                            data_name_list)
+    for index, n_name in enumerate(op_name_reorder):
+        _, state = get_name_and_state(n_name)
+        struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
+        if not struct_key:
+            continue
+        n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
+        struct_to_index_mapping[struct_key] += 1

         try:
             result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
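The four lengths returned by `count_struct` drive the `get_accuracy_core` offsets in the hunk above. A toy illustration (not msprobe code) of how those start/length pairs carve a flat op_name list into input/output/parameters/parameters_grad sections; the names and counts are invented:

```python
# Toy data: rows of one merged op laid out as input, output, parameters, parameters_grad.
op_names = ['op.input.0', 'op.input.1', 'op.output.0', 'op.parameters.weight', 'op.parameters_grad.weight']
n_input, n_output, n_params, n_params_grad = 2, 1, 1, 1

# Same start/length arithmetic as the get_accuracy_core calls above.
sections = {
    'input': (0, n_input),
    'output': (n_input, n_output),
    'parameters': (n_input + n_output, n_params),
    'parameters_grad': (n_input + n_output + n_params, n_params_grad),
}
for state, (start, length) in sections.items():
    print(state, op_names[start:start + length])
```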
@@ -390,28 +446,26 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
                       f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
             logger.error(err_msg)
             raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
+
         if dump_mode == Const.MD5:
             result_item.extend([CompareConst.N_A] * 3)
-                result_item.extend(npu_stack_info)
-            else:
-                result_item.append(CompareConst.NONE)
+            append_stack_info(result_item, npu_stack_info, index)
             result.append(result_item)
             continue
         if dump_mode == Const.SUMMARY:
             result_item.extend([CompareConst.N_A] * 8)
+        if dump_mode == Const.ALL:
             result_item.extend([CompareConst.N_A] * 5)

+        npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
         bench_summary_data = [CompareConst.N_A] * 4
+        result_item.extend(npu_summary_data)
         result_item.extend(bench_summary_data)
+        err_msg = CompareConst.NO_BENCH
+        accuracy_check_res = CompareConst.N_A
         result_item.append(accuracy_check_res)
         result_item.append(err_msg)
-            result_item.extend(npu_stack_info)
-        else:
-            result_item.append(CompareConst.NONE)
+        append_stack_info(result_item, npu_stack_info, index)
         if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
             result_item.extend(["-1"])
         result.append(result_item)
@@ -423,6 +477,8 @@ def merge_tensor(tensor_list, dump_mode):
     op_dict[CompareConst.INPUT_STRUCT] = []
     op_dict[CompareConst.KWARGS_STRUCT] = []
     op_dict[CompareConst.OUTPUT_STRUCT] = []
+    op_dict[CompareConst.PARAMS_STRUCT] = []
+    op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
     op_dict[Const.SUMMARY] = []
     op_dict["stack_info"] = []

@@ -430,30 +486,25 @@ def merge_tensor(tensor_list, dump_mode):
         op_dict["data_name"] = []

     for tensor in tensor_list:
+        # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
         if len(tensor) == 2:
             op_dict['stack_info'].append(tensor['full_info'])
             break
+
         op_dict["op_name"].append(tensor['full_op_name'])
-            op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
-        else:
-            op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
-        break
+
+        _, state = get_name_and_state(tensor['full_op_name'])
+        struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
+        if not struct_key:
+            continue
+        if dump_mode == Const.MD5:
+            op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
+        else:
+            op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
         op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])

         if dump_mode == Const.ALL:
             op_dict["data_name"].append(tensor['data_name'])
-            data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
-            if data_name != "-1":
-                op_dict["op_name"][-1] = data_name

     if not op_dict[CompareConst.KWARGS_STRUCT]:
         del op_dict[CompareConst.KWARGS_STRUCT]
@@ -467,11 +518,90 @@ def print_compare_ends_info():
     logger.info('*' * total_len)


+def table_value_is_valid(value: str) -> bool:
+    if not isinstance(value, str):
+        return True
+    try:
+        # -1.00 or +1.00 should be considered as digit numbers
+        float(value)
+    except ValueError:
+        # otherwise, they will be considered as formula injections
+        return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
+    return True
+
+
+def get_name_and_state(name):
+    """
+    Get api/module name and state
+    example:
+    name = 'conv2d.forward.1.input.0'
+    return: ('conv2d.forward.1.', 'input')
+
+    name = 'Functional.pad.0.backward.output.0'
+    return: ('Functional.pad.0.backward.', 'output')
+
+    state type: input, output, kwargs, parameters, parameters_grad
+    """
+    if Const.PARAMS_GRAD in name.split(Const.SEP):
+        return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
+
+    split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
+    api = f'{split[0]}.{split[1]}.'
+    state_str = split[2]
+    match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
+    if not match:
+        raise CompareException(f'Invalid name string: {name}')
+    if match.group(1):
+        api = f'{api}{match.group(1)}'
+    state = match.group(2)
+    return api, state
+
+
+def reorder_op_name_list(op_name_list):
+    if not op_name_list:
+        return op_name_list
+
+    parameters = []
+    output = []
+    parameters_grad = []
+    others = []
+    for x in op_name_list:
+        state = get_name_and_state(x)[1]
+        if state == Const.PARAMS:
+            parameters.append(x)
+        elif state == Const.OUTPUT:
+            output.append(x)
+        elif state == Const.PARAMS_GRAD:
+            parameters_grad.append(x)
+        else:
+            others.append(x)
+    # merge others, parameters and output, keeping parameters ahead of output
+    op_name_reorder = others + parameters + output + parameters_grad
+    return op_name_reorder
+
+
+def reorder_op_x_list(op_name_list, summary_list, data_name_list):
+    """Reorder op_name, summary and data_name, moving parameters after input and before output;
+    data_name is None in statistics (summary) compare and is handled separately."""
+    if not op_name_list or not summary_list:
+        return op_name_list, summary_list, data_name_list
+
+    index_map = {name: index for index, name in enumerate(op_name_list)}
+
+    op_name_reorder = reorder_op_name_list(op_name_list)
+    summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
+    if data_name_list:
+        data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
+    else:
+        data_name_reorder = data_name_list
+
+    return op_name_reorder, summary_reorder, data_name_reorder
+
+
 def _compare_parser(parser):
     parser.add_argument("-i", "--input_path", dest="input_path", type=str,
                         help="<Required> The compare input path, a dict json.", required=True)
     parser.add_argument("-o", "--output_path", dest="output_path", type=str,
-                        help="<Required> The compare task result out path. Default path: ./output",
+                        help="<Required> The compare task result out path. Default path: ./output",
                         required=False, default="./output", nargs="?", const="./output")
     parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
                         help="<optional> Whether to save stack info.", required=False)
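To make the new name parsing concrete, here is a standalone sketch of the split that `get_name_and_state` performs. `Const.REGEX_FORWARD_BACKWARD` is assumed here to be equivalent to `r'\.(forward|backward)\.'` (an assumption, not confirmed by this diff); the example names come from the docstring above.

```python
import re

# Assumed stand-in for Const.REGEX_FORWARD_BACKWARD.
REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.'

def split_name_and_state(name):
    # Split "prefix.forward/backward.rest" while keeping the direction token.
    prefix, direction, rest = re.split(REGEX_FORWARD_BACKWARD, name, maxsplit=1)
    api = f'{prefix}.{direction}.'
    match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', rest)
    if not match:
        raise ValueError(f'Invalid name string: {name}')
    if match.group(1):  # optional call-count segment such as '1.'
        api += match.group(1)
    return api, match.group(2)

print(split_name_and_state('conv2d.forward.1.input.0'))            # ('conv2d.forward.1.', 'input')
print(split_name_and_state('Functional.pad.0.backward.output.0'))  # ('Functional.pad.0.backward.', 'output')
```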
msprobe/core/data_dump/data_collector.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -38,6 +38,9 @@ class DataCollector:
         self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
         self.module_count = {}
         self.scope = ScopeFactory(self.config).build_scope()
+        self.backward_module_names = {}
+        self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
         atexit.register(self.write_json)

     @property
@@ -53,8 +56,15 @@ class DataCollector:
         return (not scope or scope.check(name)) and pid == os.getpid()

     @staticmethod
-    def
+    def set_is_recomputable(data_info, is_recompute):
+        if data_info and len(data_info) == 1 and is_recompute is not None:  # normally data_info has exactly one entry
+            data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
+
+    def reset_status(self):
+        self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
+        self.data_writer.reset_cache()
+        self.backward_module_names.clear()

     def if_return_forward_new_output(self):
         return self.data_processor.if_return_forward_new_output()
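For clarity, a small sketch of what `set_is_recomputable` does to a single-entry `data_info` dict; the logic mirrors the hunk above, and the entry name is invented for illustration.

```python
def set_is_recomputable(data_info, is_recompute):
    # Tag the single collected entry with the recompute flag, if one was provided.
    if data_info and len(data_info) == 1 and is_recompute is not None:
        data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute

data_info = {"Module.layer1.forward.0": {"input_args": [], "output": []}}
set_is_recomputable(data_info, True)
print(data_info["Module.layer1.forward.0"]["is_recompute"])  # True
```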
@@ -79,69 +89,105 @@ class DataCollector:
         logger.debug(msg)
         self.data_writer.update_data(data_info)

-    def
-        if self.config.
+    def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
+        if self.config.task == Const.FREE_BENCHMARK:
+            backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
+            if self.check_scope_and_pid(self.scope, backward_name, pid):
+                self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
+            return
+
+        if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        if self.
-            self.data_processor.
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
+        if self.config.level == Const.LEVEL_L2:
             return
-        logger.info(f"API {name} is inplace.")
-        data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def
+    def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
+
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         if self.config.level == Const.LEVEL_L2:
+            return
+        self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
+    def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
+        self.update_construct(name)
+        if not self.check_scope_and_pid(self.scope, name, pid):
             return

+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
             data_info = self.data_processor.analyze_forward(name, module, module_input_output)
-        data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def backward_data_collect(self, name, module, pid, module_input_output):
+    def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info =
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward(name, module, module_input_output)
         if self.config.level == Const.LEVEL_L2:
             return
+        # get the name of the module that ran this backward pass
+        if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+            module_name = name.rsplit(Const.SEP, 2)[0]
+            # record the module name so parameter-gradient collection knows whether to collect for it
+            self.backward_module_names[module_name] = True
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def backward_input_data_collect(self, name, module, pid, module_input_output):
+    def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info =
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)

-    def backward_output_data_collect(self, name, module, pid, module_input_output):
+    def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info =
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)

     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
-            self.
+            if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+                if self.optimizer_status_first_start[self.optimizer_status]:
+                    self.data_writer.update_construct({self.optimizer_status: None})
+                    self.optimizer_status_first_start[self.optimizer_status] = False
+                self.data_writer.update_construct({name: self.optimizer_status})
+            else:
+                self.data_writer.update_construct({name: self.module_processor.api_parent_node})
             self.data_writer.update_construct(self.module_processor.module_node)

     def handle_data(self, name, data_info, flush=False):
         if data_info:
             self.update_data(name, data_info)
+        if self.config.async_dump:
+            return
         if not flush:
             self.data_writer.flush_data_periodically()
         else:
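A toy illustration (not msprobe code) of how the reworked `update_construct` parents nodes while an optimizer phase is active; the literal strings `"optimizer"` and `"clip_grad"` are assumed stand-ins for `Const.OPTIMIZER` and `Const.CLIP_GRAD`.

```python
construct = {}
optimizer_status = "optimizer"                    # set by the debugger around optimizer steps
first_start = {"optimizer": True, "clip_grad": True}

def record(name, api_parent_node):
    if optimizer_status in ("optimizer", "clip_grad"):
        if first_start[optimizer_status]:
            construct[optimizer_status] = None    # the phase node itself has no parent
            first_start[optimizer_status] = False
        construct[name] = optimizer_status        # APIs seen during the phase hang off it
    else:
        construct[name] = api_parent_node

record("Torch.add.0.forward", api_parent_node=None)
print(construct)  # {'optimizer': None, 'Torch.add.0.forward': 'optimizer'}
```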
@@ -149,7 +195,36 @@ class DataCollector:

     def update_dump_paths(self, *args):
         self.data_writer.update_dump_paths(*args)
-
+
+    def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
+        self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)

     def update_iter(self, current_iter):
         self.data_processor.update_iter(current_iter)
+
+    def params_data_collect(self, name, param_name, pid, data):
+        grad_name = name + Const.SEP + Const.PARAMS_GRAD
+        # check scope and pid, and whether this name has run a backward pass
+        if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+            # if backward never ran, drop the placeholder grad entry written earlier
+            if self.data_writer.cache_data.get("data"):
+                self.data_writer.cache_data.get("data").pop(grad_name, None)
+            return
+        data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+        self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+
+    def fill_stack_tensor_data(self):
+        self.data_writer.fill_stack_tensor_data()
+
+    def debug_data_collect_forward(self, variable, name_with_count):
+
+        data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
+        self.data_writer.update_debug({name_with_count: data_info})
+
+    def debug_data_collect_backward(self, variable, grad_name_with_count):
+        # prepare an all-None nested data structure
+        all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
+        self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
+
+        # register tensor backward hook
+        self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])