mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,7 +21,7 @@ from dataclasses import dataclass
|
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
23
|
|
|
24
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
24
|
+
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
25
25
|
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
|
|
26
26
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
27
27
|
|
|
@@ -37,13 +37,20 @@ def extract_json(dirname, stack_json=False):
|
|
|
37
37
|
# Provide robustness on invalid directory inputs
|
|
38
38
|
if not json_path:
|
|
39
39
|
if stack_json:
|
|
40
|
-
logger.
|
|
40
|
+
logger.warning(f'stack.json is not found in dump dir {dirname}.')
|
|
41
41
|
else:
|
|
42
42
|
logger.error(f'dump.json is not found in dump dir {dirname}.')
|
|
43
|
-
|
|
43
|
+
raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
|
|
44
44
|
return json_path
|
|
45
45
|
|
|
46
46
|
|
|
47
|
+
def set_stack_json_path(input_param):
    """Look for stack.json next to the NPU dump file and record its path in ``input_param``.

    ``input_param["stack_json_path"]`` is set to the path returned by ``extract_json`` for the
    directory containing ``input_param["npu_json_path"]``, or to None if that path is falsy.

    Returns:
        bool: True when a stack json path was found, False otherwise.
    """
    npu_json_path = input_param.get("npu_json_path")
    stack_path = extract_json(os.path.dirname(npu_json_path), stack_json=True)
    found = bool(stack_path)
    input_param["stack_json_path"] = stack_path if found else None
    return found
|
|
52
|
+
|
|
53
|
+
|
|
47
54
|
def check_and_return_dir_contents(dump_dir, prefix):
|
|
48
55
|
"""
|
|
49
56
|
check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
|
|
@@ -75,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
75
82
|
|
|
76
83
|
|
|
77
84
|
def rename_api(npu_name, process):
|
|
85
|
+
"""
|
|
86
|
+
原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
|
|
87
|
+
rename后: {api_type}.{api_name}.{input/output}.{参数序号}
|
|
88
|
+
"""
|
|
78
89
|
npu_split = npu_name.split(process)
|
|
79
90
|
try:
|
|
80
91
|
torch_func_index, in_out = npu_split[0], npu_split[1]
|
|
@@ -87,17 +98,13 @@ def rename_api(npu_name, process):
|
|
|
87
98
|
|
|
88
99
|
|
|
89
100
|
def read_op(op_data, op_name):
    """Flatten the dump data of one op into a list of tensor items.

    If ``Const.PARAMS_GRAD`` is one of the dot-separated segments of ``op_name``, the whole
    ``op_data`` is parsed directly under ``op_name``. Otherwise, every key of
    ``CompareConst.IO_NAME_MAPPING`` that is present in ``op_data`` is parsed, in mapping
    order, with its mapped suffix appended to ``op_name``.
    """
    if Const.PARAMS_GRAD in op_name.split(Const.SEP):
        return op_item_parse(op_data, op_name)

    io_mapping = CompareConst.IO_NAME_MAPPING
    parsed = []
    for io_key in io_mapping:
        if io_key not in op_data:
            continue
        parsed.extend(op_item_parse(op_data[io_key], op_name + io_mapping[io_key]))
    return parsed
|
|
102
109
|
|
|
103
110
|
|
|
@@ -124,11 +131,14 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
|
|
|
124
131
|
return [default_item]
|
|
125
132
|
elif not op_data:
|
|
126
133
|
return []
|
|
127
|
-
|
|
134
|
+
|
|
128
135
|
item_list = []
|
|
129
136
|
if isinstance(op_data, list):
|
|
130
137
|
for i, data in enumerate(op_data):
|
|
131
|
-
|
|
138
|
+
if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
|
|
139
|
+
item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
|
|
140
|
+
else:
|
|
141
|
+
item_list.extend(op_item_parse(data, op_name, depth + 1))
|
|
132
142
|
elif isinstance(op_data, dict):
|
|
133
143
|
if is_leaf_data(op_data):
|
|
134
144
|
return [gen_op_item(op_data, op_name)]
|
|
@@ -144,14 +154,15 @@ def is_leaf_data(op_data):
|
|
|
144
154
|
def gen_op_item(op_data, op_name):
|
|
145
155
|
op_item = {}
|
|
146
156
|
op_item.update(op_data)
|
|
147
|
-
|
|
148
|
-
op_item['data_name'] =
|
|
157
|
+
data_name = op_data.get('data_name') if op_data.get('data_name') else '-1' # 如果是""也返回-1
|
|
158
|
+
op_item['data_name'] = data_name
|
|
159
|
+
op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name
|
|
149
160
|
|
|
150
161
|
params = ['Max', 'Min', 'Mean', 'Norm']
|
|
151
162
|
for i in params:
|
|
152
163
|
if i not in op_item:
|
|
153
164
|
op_item[i] = None
|
|
154
|
-
|
|
165
|
+
|
|
155
166
|
if not op_item.get('dtype'):
|
|
156
167
|
if op_item.get('type') == 'torch.Size':
|
|
157
168
|
op_item['dtype'] = op_data.get('type')
|
|
@@ -166,7 +177,7 @@ def gen_op_item(op_data, op_name):
|
|
|
166
177
|
op_item[i] = op_data.get('value')
|
|
167
178
|
if not op_item.get('md5'):
|
|
168
179
|
op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
|
|
169
|
-
|
|
180
|
+
|
|
170
181
|
return op_item
|
|
171
182
|
|
|
172
183
|
|
|
@@ -276,6 +287,22 @@ def result_item_init(n_info, b_info, dump_mode):
|
|
|
276
287
|
return result_item
|
|
277
288
|
|
|
278
289
|
|
|
290
|
+
def count_struct(op_dict):
    """Count op names and the entries of each struct list stored in ``op_dict``.

    Missing keys count as 0.

    Returns:
        tuple: (op_name count, input_struct count, output_struct count,
                params_struct count, params_grad_struct count)

    Raises:
        CompareException: NAMES_STRUCTS_MATCH_ERROR (after logging) when the number of op
            names differs from the combined length of the four struct lists.
    """
    name_count = len(op_dict.get(CompareConst.OP_NAME, []))
    struct_keys = (
        CompareConst.INPUT_STRUCT,
        CompareConst.OUTPUT_STRUCT,
        CompareConst.PARAMS_STRUCT,
        CompareConst.PARAMS_GRAD_STRUCT,
    )
    struct_counts = tuple(len(op_dict.get(key, [])) for key in struct_keys)
    if sum(struct_counts) != name_count:
        logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
        raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
    return (name_count,) + struct_counts
|
|
304
|
+
|
|
305
|
+
|
|
279
306
|
def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
280
307
|
def get_accuracy_core(n_start, n_len, b_start, b_len, key):
|
|
281
308
|
min_len = min(n_len, b_len)
|
|
@@ -355,31 +382,50 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
355
382
|
|
|
356
383
|
result.append(result_item)
|
|
357
384
|
|
|
358
|
-
n_num =
|
|
359
|
-
b_num =
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
n_num_output
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
385
|
+
n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
|
|
386
|
+
b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
|
|
387
|
+
|
|
388
|
+
get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
|
|
389
|
+
get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
|
|
390
|
+
CompareConst.PARAMS_STRUCT)
|
|
391
|
+
get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
|
|
392
|
+
get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
|
|
393
|
+
b_num_input + b_num_output + b_num_params, b_num_params_grad,
|
|
394
|
+
CompareConst.PARAMS_GRAD_STRUCT)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def append_stack_info(result_item, npu_stack_info, index):
    """Append stack information to ``result_item`` in place.

    Only the row with ``index == 0`` receives the NPU stack info, and only when it is
    non-empty. Every other case appends a single ``CompareConst.NONE`` placeholder.
    """
    attach_stack = bool(npu_stack_info) and index == 0
    if not attach_stack:
        result_item.append(CompareConst.NONE)
        return
    result_item.extend(npu_stack_info)
|
|
368
403
|
|
|
369
404
|
|
|
370
405
|
def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
371
|
-
index_out = 0
|
|
372
406
|
npu_stack_info = n_dict.get("stack_info", None)
|
|
373
407
|
bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
408
|
+
|
|
409
|
+
struct_to_index_mapping = {
|
|
410
|
+
CompareConst.INPUT_STRUCT: 0,
|
|
411
|
+
CompareConst.OUTPUT_STRUCT: 0,
|
|
412
|
+
CompareConst.PARAMS_STRUCT: 0,
|
|
413
|
+
CompareConst.PARAMS_GRAD_STRUCT: 0
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
op_name_list = n_dict.get(CompareConst.OP_NAME)
|
|
417
|
+
summary_list = n_dict.get(Const.SUMMARY)
|
|
418
|
+
data_name_list = n_dict.get('data_name')
|
|
419
|
+
op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
|
|
420
|
+
summary_list,
|
|
421
|
+
data_name_list)
|
|
422
|
+
for index, n_name in enumerate(op_name_reorder):
|
|
423
|
+
_, state = get_name_and_state(n_name)
|
|
424
|
+
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
425
|
+
if not struct_key:
|
|
426
|
+
continue
|
|
427
|
+
n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
|
|
428
|
+
struct_to_index_mapping[struct_key] += 1
|
|
383
429
|
|
|
384
430
|
try:
|
|
385
431
|
result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
|
|
@@ -390,28 +436,26 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
|
390
436
|
f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
|
|
391
437
|
logger.error(err_msg)
|
|
392
438
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
439
|
+
|
|
393
440
|
if dump_mode == Const.MD5:
|
|
394
441
|
result_item.extend([CompareConst.N_A] * 3)
|
|
395
|
-
|
|
396
|
-
result_item.extend(npu_stack_info)
|
|
397
|
-
else:
|
|
398
|
-
result_item.append(CompareConst.NONE)
|
|
442
|
+
append_stack_info(result_item, npu_stack_info, index)
|
|
399
443
|
result.append(result_item)
|
|
400
444
|
continue
|
|
401
445
|
if dump_mode == Const.SUMMARY:
|
|
402
446
|
result_item.extend([CompareConst.N_A] * 8)
|
|
403
|
-
|
|
447
|
+
if dump_mode == Const.ALL:
|
|
404
448
|
result_item.extend([CompareConst.N_A] * 5)
|
|
405
|
-
|
|
406
|
-
|
|
449
|
+
|
|
450
|
+
npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
|
|
407
451
|
bench_summary_data = [CompareConst.N_A] * 4
|
|
452
|
+
result_item.extend(npu_summary_data)
|
|
408
453
|
result_item.extend(bench_summary_data)
|
|
454
|
+
err_msg = CompareConst.NO_BENCH
|
|
455
|
+
accuracy_check_res = CompareConst.N_A
|
|
409
456
|
result_item.append(accuracy_check_res)
|
|
410
457
|
result_item.append(err_msg)
|
|
411
|
-
|
|
412
|
-
result_item.extend(npu_stack_info)
|
|
413
|
-
else:
|
|
414
|
-
result_item.append(CompareConst.NONE)
|
|
458
|
+
append_stack_info(result_item, npu_stack_info, index)
|
|
415
459
|
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
416
460
|
result_item.extend(["-1"])
|
|
417
461
|
result.append(result_item)
|
|
@@ -423,6 +467,8 @@ def merge_tensor(tensor_list, dump_mode):
|
|
|
423
467
|
op_dict[CompareConst.INPUT_STRUCT] = []
|
|
424
468
|
op_dict[CompareConst.KWARGS_STRUCT] = []
|
|
425
469
|
op_dict[CompareConst.OUTPUT_STRUCT] = []
|
|
470
|
+
op_dict[CompareConst.PARAMS_STRUCT] = []
|
|
471
|
+
op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
|
|
426
472
|
op_dict[Const.SUMMARY] = []
|
|
427
473
|
op_dict["stack_info"] = []
|
|
428
474
|
|
|
@@ -430,30 +476,25 @@ def merge_tensor(tensor_list, dump_mode):
|
|
|
430
476
|
op_dict["data_name"] = []
|
|
431
477
|
|
|
432
478
|
for tensor in tensor_list:
|
|
479
|
+
# A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
|
|
433
480
|
if len(tensor) == 2:
|
|
434
481
|
op_dict['stack_info'].append(tensor['full_info'])
|
|
435
482
|
break
|
|
483
|
+
|
|
436
484
|
op_dict["op_name"].append(tensor['full_op_name'])
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
447
|
-
else:
|
|
448
|
-
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
449
|
-
break
|
|
485
|
+
|
|
486
|
+
_, state = get_name_and_state(tensor['full_op_name'])
|
|
487
|
+
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
488
|
+
if not struct_key:
|
|
489
|
+
continue
|
|
490
|
+
if dump_mode == Const.MD5:
|
|
491
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
492
|
+
else:
|
|
493
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
450
494
|
op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
|
|
451
495
|
|
|
452
496
|
if dump_mode == Const.ALL:
|
|
453
497
|
op_dict["data_name"].append(tensor['data_name'])
|
|
454
|
-
data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
|
|
455
|
-
if data_name != "-1":
|
|
456
|
-
op_dict["op_name"][-1] = data_name
|
|
457
498
|
|
|
458
499
|
if not op_dict[CompareConst.KWARGS_STRUCT]:
|
|
459
500
|
del op_dict[CompareConst.KWARGS_STRUCT]
|
|
@@ -467,11 +508,90 @@ def print_compare_ends_info():
|
|
|
467
508
|
logger.info('*' * total_len)
|
|
468
509
|
|
|
469
510
|
|
|
511
|
+
def table_value_is_valid(value: str) -> bool:
    """Return False if ``value`` looks like a spreadsheet formula injection, True otherwise.

    Non-string values are always accepted. Strings that parse as a float (for example
    "-1.00" or "+1.00") are accepted as plain numbers. Any other string is rejected when
    it matches ``FileCheckConst.CSV_BLACK_LIST``.
    """
    if isinstance(value, str):
        try:
            # Signed numbers such as -1.00 or +1.00 must be treated as numeric values.
            float(value)
        except ValueError:
            # Any non-numeric string is checked against the blacklist as a possible
            # formula injection.
            blacklist = re.compile(FileCheckConst.CSV_BLACK_LIST)
            return blacklist.search(value) is None
    return True
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def get_name_and_state(name):
    """
    Split a dump entry name into its api/module prefix and its state.

    Examples:
        'conv2d.forward.1.input.0'           -> ('conv2d.forward.1.', 'input')
        'Functional.pad.0.backward.output.0' -> ('Functional.pad.0.backward.', 'output')

    If ``Const.PARAMS_GRAD`` is one of the dot-separated segments of ``name``, the result is
    (the text before ``Const.PARAMS_GRAD``, ``Const.PARAMS_GRAD``).

    State values: input, output, kwargs, parameters, parameters_grad.

    Raises:
        CompareException: if ``name`` has no forward/backward segment, or the text after that
            segment is not an optional ``<digits>.`` index followed by a known state and a
            further segment. The offending name is logged before raising.
    """
    if Const.PARAMS_GRAD in name.split(Const.SEP):
        return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD

    split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
    # The pattern is expected to capture the direction token (see the examples above), so a
    # well-formed name yields at least [prefix, direction, rest]. A shorter result means the
    # name has no forward/backward segment. Report it as an invalid name instead of failing
    # with a bare IndexError on split[1] / split[2].
    match = None
    if len(split) >= 3:
        match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', split[2])
    if not match:
        # Other raises in this module pass an error-code constant to CompareException and
        # log the details separately; follow the same convention here.
        logger.error(f'Invalid name string: {name}')
        raise CompareException(CompareException.INVALID_DATA_ERROR)

    api = f'{split[0]}.{split[1]}.'
    if match.group(1):
        # Keep the call-count index (e.g. '1.') as part of the api prefix.
        api = f'{api}{match.group(1)}'
    return api, match.group(2)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def reorder_op_name_list(op_name_list):
    """Reorder op names so that parameters come after inputs/kwargs and before outputs.

    The result contains all names with other states (e.g. input, kwargs) first, then
    parameters, then output, then parameters_grad. Relative order inside each group is
    preserved. An empty or None list is returned unchanged.
    """
    if not op_name_list:
        return op_name_list

    grouped = {Const.PARAMS: [], Const.OUTPUT: [], Const.PARAMS_GRAD: []}
    rest = []
    for op_name in op_name_list:
        _, state = get_name_and_state(op_name)
        grouped.get(state, rest).append(op_name)
    return rest + grouped[Const.PARAMS] + grouped[Const.OUTPUT] + grouped[Const.PARAMS_GRAD]
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def reorder_op_x_list(op_name_list, summary_list, data_name_list):
    """Reorder op names, summaries and data names consistently.

    The new order is decided by ``reorder_op_name_list`` (parameters after inputs and before
    outputs), and ``summary_list`` / ``data_name_list`` are permuted the same way.
    ``data_name_list`` is empty/None in statistics (summary) comparison; in that case it is
    returned unchanged.

    Returns:
        tuple: (op_name_reorder, summary_reorder, data_name_reorder). If ``op_name_list`` or
        ``summary_list`` is empty/None, the three inputs are returned unchanged.
    """
    if not op_name_list or not summary_list:
        return op_name_list, summary_list, data_name_list

    op_name_reorder = reorder_op_name_list(op_name_list)

    # Map every reordered name back to its original position. The same name can occur more
    # than once (e.g. parameters_grad list entries share one op name when no data_name is
    # available), so each name keeps a queue of its positions rather than a single index.
    # reorder_op_name_list is a stable partition, so duplicates keep their relative order
    # and consuming the queue front-to-back pairs them correctly.
    positions = {}
    for original_index, name in enumerate(op_name_list):
        positions.setdefault(name, []).append(original_index)
    cursors = {name: iter(indices) for name, indices in positions.items()}
    new_order = [next(cursors[name]) for name in op_name_reorder]

    summary_reorder = [summary_list[i] for i in new_order]
    if data_name_list:
        data_name_reorder = [data_name_list[i] for i in new_order]
    else:
        data_name_reorder = data_name_list

    return op_name_reorder, summary_reorder, data_name_reorder
|
|
588
|
+
|
|
589
|
+
|
|
470
590
|
def _compare_parser(parser):
|
|
471
591
|
parser.add_argument("-i", "--input_path", dest="input_path", type=str,
|
|
472
592
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
473
593
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
474
|
-
help="<Required> The compare task result out path. Default path: ./output",
|
|
594
|
+
help="<Required> The compare task result out path. Default path: ./output",
|
|
475
595
|
required=False, default="./output", nargs="?", const="./output")
|
|
476
596
|
parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
|
|
477
597
|
help="<optional> Whether to save stack info.", required=False)
|
|
@@ -38,6 +38,8 @@ class DataCollector:
|
|
|
38
38
|
self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
|
|
39
39
|
self.module_count = {}
|
|
40
40
|
self.scope = ScopeFactory(self.config).build_scope()
|
|
41
|
+
self.backward_module_names = {}
|
|
42
|
+
self.optimizer_status = ""
|
|
41
43
|
atexit.register(self.write_json)
|
|
42
44
|
|
|
43
45
|
@property
|
|
@@ -52,10 +54,6 @@ class DataCollector:
|
|
|
52
54
|
def check_scope_and_pid(scope, name, pid):
    # An empty/None scope accepts every name; otherwise scope.check(name) decides.
    # The name must also be in scope and pid must equal the current process id.
    in_scope = not scope or scope.check(name)
    if not in_scope:
        return in_scope
    return pid == os.getpid()
|
|
54
56
|
|
|
55
|
-
@staticmethod
|
|
56
|
-
def is_inplace(module):
|
|
57
|
-
return getattr(module, "op_is_inplace", False)
|
|
58
|
-
|
|
59
57
|
def if_return_forward_new_output(self):
    """Return whatever the data processor reports for ``if_return_forward_new_output()``."""
    processor = self.data_processor
    return processor.if_return_forward_new_output()
|
|
61
59
|
|
|
@@ -79,32 +77,38 @@ class DataCollector:
|
|
|
79
77
|
logger.debug(msg)
|
|
80
78
|
self.data_writer.update_data(data_info)
|
|
81
79
|
|
|
82
|
-
def
|
|
83
|
-
if self.config.
|
|
84
|
-
|
|
80
|
+
def forward_input_data_collect(self, name, module, pid, module_input_output):
    """Collect forward-input data for ``name``.

    For the free-benchmark task, the input is analysed under the matching backward name
    (``Const.FORWARD`` replaced by ``Const.BACKWARD``), provided that name passes the
    scope/pid check. The analysis result is not passed to ``handle_data`` in that case.

    For other tasks, a name that passes the scope/pid check has its input analysed. The
    result goes to ``handle_data`` unless the dump level is ``Const.LEVEL_L2``.
    """
    if self.config.task == Const.FREE_BENCHMARK:
        bwd_name = name.replace(Const.FORWARD, Const.BACKWARD)
        if self.check_scope_and_pid(self.scope, bwd_name, pid):
            self.data_processor.analyze_forward_input(bwd_name, module, module_input_output)
        return

    if self.check_scope_and_pid(self.scope, name, pid):
        info = self.data_processor.analyze_forward_input(name, module, module_input_output)
        if self.config.level != Const.LEVEL_L2:
            self.handle_data(name, info, flush=self.data_processor.is_terminated)
|
|
95
94
|
|
|
96
|
-
def
|
|
95
|
+
def forward_output_data_collect(self, name, module, pid, module_input_output):
    """Collect forward-output data for ``name``.

    Construct info is always updated first. If the name passes the scope/pid check, its
    output is analysed. Unless the dump level is ``Const.LEVEL_L2``, the api call stack is
    then recorded and the analysed data is handed to ``handle_data``.
    """
    self.update_construct(name)
    if self.check_scope_and_pid(self.scope, name, pid):
        info = self.data_processor.analyze_forward_output(name, module, module_input_output)
        if self.config.level != Const.LEVEL_L2:
            self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
            self.handle_data(name, info, flush=self.data_processor.is_terminated)
|
|
103
105
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
def forward_data_collect(self, name, module, pid, module_input_output):
    """Collect forward data (as analysed by ``data_processor.analyze_forward``) for ``name``.

    Construct info is always updated first. The analysis, stack recording and
    ``handle_data`` call only happen when the name passes the scope/pid check.
    """
    self.update_construct(name)
    if not self.check_scope_and_pid(self.scope, name, pid):
        return
    processed = self.data_processor.analyze_forward(name, module, module_input_output)
    call_stack = self.data_processor.analyze_api_call_stack(name)
    self.data_writer.update_stack(call_stack)
    self.handle_data(name, processed, flush=self.data_processor.is_terminated)
|
|
110
114
|
|
|
@@ -116,6 +120,11 @@ class DataCollector:
|
|
|
116
120
|
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
117
121
|
if self.config.level == Const.LEVEL_L2:
|
|
118
122
|
return
|
|
123
|
+
# 获取执行反向的模块名称
|
|
124
|
+
if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
|
|
125
|
+
module_name = name.rsplit(Const.SEP, 2)[0]
|
|
126
|
+
# 将模块名称加入到反向模块名称集合中,用于梯度收集时判断是否需要收集梯度
|
|
127
|
+
self.backward_module_names[module_name] = True
|
|
119
128
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
120
129
|
|
|
121
130
|
def backward_input_data_collect(self, name, module, pid, module_input_output):
|
|
@@ -136,12 +145,17 @@ class DataCollector:
|
|
|
136
145
|
|
|
137
146
|
def update_construct(self, name):
|
|
138
147
|
if self.config.level not in DataCollector.level_without_construct:
|
|
139
|
-
self.
|
|
148
|
+
if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
|
|
149
|
+
self.data_writer.update_construct({name: self.optimizer_status})
|
|
150
|
+
else:
|
|
151
|
+
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
140
152
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
141
153
|
|
|
142
154
|
def handle_data(self, name, data_info, flush=False):
|
|
143
155
|
if data_info:
|
|
144
156
|
self.update_data(name, data_info)
|
|
157
|
+
if self.config.async_dump:
|
|
158
|
+
return
|
|
145
159
|
if not flush:
|
|
146
160
|
self.data_writer.flush_data_periodically()
|
|
147
161
|
else:
|
|
@@ -149,7 +163,23 @@ class DataCollector:
|
|
|
149
163
|
|
|
150
164
|
def update_dump_paths(self, *args):
|
|
151
165
|
self.data_writer.update_dump_paths(*args)
|
|
152
|
-
|
|
166
|
+
|
|
167
|
+
def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
|
|
168
|
+
self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)
|
|
153
169
|
|
|
154
170
|
def update_iter(self, current_iter):
|
|
155
171
|
self.data_processor.update_iter(current_iter)
|
|
172
|
+
|
|
173
|
+
def params_data_collect(self, name, param_name, pid, data):
|
|
174
|
+
grad_name = name + Const.SEP + Const.PARAMS_GRAD
|
|
175
|
+
# 校验scope和pid,以及当前name是否有过反向计算
|
|
176
|
+
if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
|
|
177
|
+
# 如果没有反向计算,则需要清除之前占位写入的grad数据
|
|
178
|
+
if self.data_writer.cache_data.get("data"):
|
|
179
|
+
self.data_writer.cache_data.get("data").pop(grad_name, None)
|
|
180
|
+
return
|
|
181
|
+
data_info = self.data_processor.analyze_params(grad_name, param_name, data)
|
|
182
|
+
self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
|
|
183
|
+
|
|
184
|
+
def fill_stack_tensor_data(self):
|
|
185
|
+
self.data_writer.fill_stack_tensor_data()
|
|
@@ -39,9 +39,8 @@ class ModuleForwardInputsOutputs:
|
|
|
39
39
|
def output_tuple(self):
|
|
40
40
|
return convert_tuple(self.output)
|
|
41
41
|
|
|
42
|
-
def
|
|
43
|
-
|
|
44
|
-
return args
|
|
42
|
+
def update_output_with_args_and_kwargs(self):
|
|
43
|
+
self.output = self.args + tuple(self.kwargs.values())
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
@dataclass
|
|
@@ -77,11 +76,12 @@ class ModuleBackwardOutputs:
|
|
|
77
76
|
|
|
78
77
|
|
|
79
78
|
class TensorStatInfo:
|
|
80
|
-
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
|
|
79
|
+
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
|
|
81
80
|
self.max = max_val
|
|
82
81
|
self.min = min_val
|
|
83
82
|
self.mean = mean_val
|
|
84
83
|
self.norm = norm_val
|
|
84
|
+
self.stack_tensor_stat = stack_tensor_stat
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
class BaseDataProcessor:
|
|
@@ -102,6 +102,7 @@ class BaseDataProcessor:
|
|
|
102
102
|
self.current_iter = 0
|
|
103
103
|
self._return_forward_new_output = False
|
|
104
104
|
self._forward_new_output = None
|
|
105
|
+
self.save_name = None
|
|
105
106
|
if hasattr(config, "data_mode"):
|
|
106
107
|
self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
|
|
107
108
|
|
|
@@ -223,7 +224,7 @@ class BaseDataProcessor:
|
|
|
223
224
|
elif isinstance(args, dict):
|
|
224
225
|
return cls.apply_transform_dict(args, transform, depth)
|
|
225
226
|
elif args is not None:
|
|
226
|
-
logger.
|
|
227
|
+
logger.debug(f"Data type {type(args)} is not supported.")
|
|
227
228
|
return None
|
|
228
229
|
else:
|
|
229
230
|
return None
|
|
@@ -273,13 +274,10 @@ class BaseDataProcessor:
|
|
|
273
274
|
"""
|
|
274
275
|
return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
|
|
275
276
|
|
|
276
|
-
def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
277
|
-
pass
|
|
278
|
-
|
|
279
277
|
def analyze_element(self, element):
|
|
280
278
|
return self.recursive_apply_transform(element, self.analyze_single_element)
|
|
281
279
|
|
|
282
|
-
def
|
|
280
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
283
281
|
api_info_struct = {}
|
|
284
282
|
# check whether data_mode contains forward or input
|
|
285
283
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
|
|
@@ -291,16 +289,22 @@ class BaseDataProcessor:
|
|
|
291
289
|
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
|
|
292
290
|
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
|
|
293
291
|
|
|
294
|
-
|
|
292
|
+
return api_info_struct
|
|
293
|
+
|
|
294
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
295
|
+
api_info_struct = {}
|
|
296
|
+
# check whether data_mode contains forward or input
|
|
295
297
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
|
|
296
|
-
api_info_struct[name] =
|
|
298
|
+
api_info_struct[name] = {}
|
|
297
299
|
self.api_data_category = Const.OUTPUT
|
|
298
300
|
output_info_list = self.analyze_element(module_input_output.output_tuple)
|
|
299
301
|
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
302
|
+
|
|
300
303
|
return api_info_struct
|
|
301
304
|
|
|
302
|
-
def
|
|
305
|
+
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
303
306
|
api_info_struct = {}
|
|
307
|
+
# check whether data_mode contains forward or input
|
|
304
308
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
|
|
305
309
|
api_info_struct[name] = {}
|
|
306
310
|
self.api_data_category = Const.INPUT
|
|
@@ -309,16 +313,18 @@ class BaseDataProcessor:
|
|
|
309
313
|
self.api_data_category = Const.KWARGS
|
|
310
314
|
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
|
|
311
315
|
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
|
|
312
|
-
return api_info_struct
|
|
313
316
|
|
|
314
|
-
|
|
315
|
-
concat_args = module_input_output.concat_args_and_kwargs()
|
|
316
|
-
api_info_struct = {}
|
|
317
|
+
# check whether data_mode contains forward or output
|
|
317
318
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
|
|
318
|
-
api_info_struct[name] = {}
|
|
319
|
+
api_info_struct[name] = api_info_struct.get(name, {})
|
|
319
320
|
self.api_data_category = Const.OUTPUT
|
|
320
|
-
output_info_list = self.analyze_element(
|
|
321
|
+
output_info_list = self.analyze_element(module_input_output.output_tuple)
|
|
321
322
|
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
323
|
+
|
|
324
|
+
if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
|
|
325
|
+
self.api_data_category = Const.PARAMS
|
|
326
|
+
api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
|
|
327
|
+
|
|
322
328
|
return api_info_struct
|
|
323
329
|
|
|
324
330
|
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
|
|
@@ -359,9 +365,21 @@ class BaseDataProcessor:
|
|
|
359
365
|
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
360
366
|
return api_info_struct
|
|
361
367
|
|
|
368
|
+
def analyze_params(self, name, param_name, grad):
|
|
369
|
+
api_info_struct = {}
|
|
370
|
+
self.save_name = name + Const.SEP + param_name
|
|
371
|
+
data_info = self.analyze_element(grad)
|
|
372
|
+
grad_info_dict = {param_name: [data_info]}
|
|
373
|
+
api_info_struct[name] = grad_info_dict
|
|
374
|
+
return api_info_struct
|
|
375
|
+
|
|
362
376
|
def get_save_file_path(self, suffix):
|
|
363
377
|
file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
|
|
364
|
-
|
|
365
|
-
|
|
378
|
+
if self.save_name is not None:
|
|
379
|
+
dump_data_name = (self.save_name + file_format)
|
|
380
|
+
self.save_name = None
|
|
381
|
+
else:
|
|
382
|
+
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
|
|
383
|
+
suffix + file_format)
|
|
366
384
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
367
385
|
return dump_data_name, file_path
|