mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/core/config_check/checkers/dataset_checker.py

```diff
@@ -22,6 +22,7 @@ from msprobe.core.config_check.config_checker import register_checker_item, regi
 from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.common.framework_adapter import FmkAdp
+from msprobe.core.common.const import Const


 @recursion_depth_decorator("config_check: process_obj")
@@ -134,5 +135,5 @@ class DatasetChecker(BaseChecker):
         cmp_dataset_pack_path = os.path.join(cmp_dir, DatasetChecker.target_name_in_zip)

         df = compare_dataset(bench_dataset_pack_path, cmp_dataset_pack_path)
-        pass_check = False not in df['equal'].values
+        pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR
         return DatasetChecker.target_name_in_zip, pass_check, df
```
msprobe/core/config_check/checkers/env_args_checker.py

```diff
@@ -21,7 +21,7 @@ import pandas as pd
 from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip
 from msprobe.core.config_check.checkers.base_checker import BaseChecker
 from msprobe.core.config_check.config_checker import register_checker_item
-from msprobe.core.config_check.utils.utils import config_checking_print
+from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check
 from msprobe.core.common.const import Const


@@ -59,17 +59,17 @@ def compare_env_data(npu_path, bench_path):
         cmp_env_name = cmp_env["name"]
         cmp_value = cmp_data.get(cmp_env_name, value[cmp_type]["default_value"])
         if not bench_env:
-            data.append(["only cmp has this env", cmp_env["name"], "", cmp_value,
+            data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, Const.CONFIG_CHECK_WARNING])
             continue
         bench_env_name = bench_env["name"]
         bench_value = bench_data.get(bench_env_name, value[bench_type]["default_value"])
         if cmp_value != bench_value:
-            data.append([bench_env_name, cmp_env_name, bench_value, cmp_value,
+            data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, Const.CONFIG_CHECK_ERROR])
     else:
         bench_env_name = bench_env["name"]
         bench_value = bench_data.get(bench_env_name) if bench_data.get(bench_env_name) else value[bench_type][
             "default_value"]
-        data.append([bench_env_name, "only bench has this env", bench_value, "",
+        data.append([bench_env_name, "only bench has this env", bench_value, "", Const.CONFIG_CHECK_WARNING])
     df = pd.DataFrame(data, columns=EnvArgsChecker.result_header)
     return df

@@ -92,5 +92,5 @@ class EnvArgsChecker(BaseChecker):
         bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip)
         cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip)
         df = compare_env_data(bench_env_data, cmp_env_data)
-        pass_check =
+        pass_check = process_pass_check(df['level'].values)
         return EnvArgsChecker.target_name_in_zip, pass_check, df
```
msprobe/core/config_check/checkers/hyperparameter_checker.py

```diff
@@ -23,7 +23,7 @@ import pandas as pd
 from msprobe.core.common.utils import check_extern_input_list
 from msprobe.core.config_check.checkers.base_checker import BaseChecker
 from msprobe.core.config_check.config_checker import register_checker_item
-from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict
+from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict, process_pass_check
 from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory
 from msprobe.core.common.file_utils import (check_file_or_directory_path, create_file_in_zip, load_json,
                                             load_yaml)
@@ -36,6 +36,20 @@ parameter_name_mapping = load_yaml(os.path.realpath(hyperparameters_path))
 hyperparameters_dict = {}


+def refine_json_keys(json_dcit):
+    new_dict = {}
+    for key in json_dcit.keys():
+        new_key = key.split(Const.SEP)[-1].replace("-", "_")
+        new_dict[new_key] = key
+    return new_dict
+
+
+def to_str_if_number(value):
+    if isinstance(value, (int, float)):
+        return str(value)
+    return value
+
+
 @register_checker_item("hyperparameter")
 class HyperparameterChecker(BaseChecker):
     target_name_in_zip = "hyperparameters"
```
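The two helpers introduced above feed the reworked compare_param in the next hunk: keys are reduced to their last dotted segment with dashes normalized to underscores, and numeric values are compared as strings. A minimal sketch of the intended effect (the sample keys are illustrative, not taken from the package, and Const.SEP is assumed to be "."):

```python
# Illustrative only: mirrors refine_json_keys / to_str_if_number from the hunk above.
def refine_json_keys(json_dict):
    # Map the normalized key form back to the original key,
    # e.g. "optimizer.learning-rate" -> {"learning_rate": "optimizer.learning-rate"}.
    return {key.split(".")[-1].replace("-", "_"): key for key in json_dict}


def to_str_if_number(value):
    # Numbers are stringified so a numeric value and its string form
    # are not reported as a difference.
    return str(value) if isinstance(value, (int, float)) else value


bench = {"optimizer.learning-rate": 0.0001}
cmp_side = {"learning_rate": "0.0001"}
print(refine_json_keys(bench))  # {'learning_rate': 'optimizer.learning-rate'}
print(to_str_if_number(bench["optimizer.learning-rate"]) == to_str_if_number(cmp_side["learning_rate"]))  # True
```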
msprobe/core/config_check/checkers/hyperparameter_checker.py (continued)

```diff
@@ -86,29 +100,35 @@ class HyperparameterChecker(BaseChecker):
             all_diffs.extend(
                 HyperparameterChecker.compare_param(bench_hyperparameters, cmp_hyperparameters, file_name))
         df = pd.DataFrame(all_diffs, columns=HyperparameterChecker.result_header)
-        pass_check =
+        pass_check = process_pass_check(df["level"].values)
         return HyperparameterChecker.target_name_in_zip, pass_check, df

     @staticmethod
     def compare_param(bench_params, cmp_params, file_name):
         all_diffs = []
-
-
+        bench_params_refined = refine_json_keys(bench_params)
+        cmp_params_refined = refine_json_keys(cmp_params)
+
+        for bench_param_name in bench_params_refined.keys():
             matched_cmp_param_name, matched_with = HyperparameterChecker._fuzzy_match_parameter(bench_param_name,
-
-
+                                                                                                cmp_params_refined)
+            matched_cmp_param_name = cmp_params_refined.get(matched_cmp_param_name)
+            bench_param_name = bench_params_refined.get(bench_param_name)
+            bench_param_value = to_str_if_number(bench_params[bench_param_name])
             if matched_cmp_param_name:
-                cmp_param_value = cmp_params[matched_cmp_param_name]
+                cmp_param_value = to_str_if_number(cmp_params[matched_cmp_param_name])
                 if bench_param_value != cmp_param_value:
                     all_diffs.append(
                         [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value,
-                         matched_with,
+                         matched_with, Const.CONFIG_CHECK_ERROR])
                 del cmp_params[matched_cmp_param_name]
             else:
                 all_diffs.append(
-                    [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "",
+                    [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "",
+                     Const.CONFIG_CHECK_WARNING])
         for cmp_param_name, cmp_param_value in cmp_params.items():
-            all_diffs.append(
+            all_diffs.append(
+                [file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", Const.CONFIG_CHECK_WARNING])
         all_diffs.sort()
         return all_diffs

```
msprobe/core/config_check/checkers/pip_checker.py

```diff
@@ -23,8 +23,9 @@ except ImportError:
 from msprobe.core.common.file_utils import load_yaml, create_file_in_zip
 from msprobe.core.config_check.checkers.base_checker import BaseChecker
 from msprobe.core.config_check.config_checker import register_checker_item
-from msprobe.core.config_check.utils.utils import config_checking_print
+from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check
 from msprobe.core.common.file_utils import FileOpen, save_excel
+from msprobe.core.common.const import Const

 dirpath = os.path.dirname(__file__)
 depend_path = os.path.join(dirpath, "../resource/dependency.yaml")
@@ -62,7 +63,7 @@ def compare_pip_data(bench_pip_path, cmp_pip_path, fmk):
             if bench_version != cmp_version:
                 data.append([package, bench_version if bench_version else 'None',
                              cmp_version if cmp_version else 'None',
-
+                             Const.CONFIG_CHECK_ERROR])

     df = pd.DataFrame(data, columns=PipPackageChecker.result_header)
     return df
@@ -86,5 +87,5 @@ class PipPackageChecker(BaseChecker):
         bench_pip_path = os.path.join(bench_dir, PipPackageChecker.target_name_in_zip)
         cmp_pip_path = os.path.join(cmp_dir, PipPackageChecker.target_name_in_zip)
         df = compare_pip_data(bench_pip_path, cmp_pip_path, fmk)
-        pass_check =
+        pass_check = process_pass_check(df['level'].values)
         return PipPackageChecker.target_name_in_zip, pass_check, df
```
msprobe/core/config_check/checkers/random_checker.py

```diff
@@ -280,9 +280,9 @@ def mindspore_patchs():
     import mindspore

     mindspore_ops_patches = {
-        'rand': mindspore.ops.
+        'rand': mindspore.ops.rand,
         'randint': mindspore.ops.randint,
-        'randn': mindspore.ops.
+        'randn': mindspore.ops.randn
     }
     for name, func in mindspore_ops_patches.items():
         setattr(mindspore.ops, name, track_random_call(func, f"mindspore.ops.{name}"))
@@ -331,7 +331,7 @@ class RandomChecker(BaseChecker):
         cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip)

         df = compare_random_calls(bench_stats_path, cmp_stats_path)
-        pass_check = False not in df['check_result'].values
+        pass_check = Const.CONFIG_CHECK_PASS if False not in df['check_result'].values else Const.CONFIG_CHECK_ERROR

         return RandomChecker.target_name_in_zip, pass_check, df

```
msprobe/core/config_check/checkers/weights_checker.py

```diff
@@ -22,6 +22,7 @@ from msprobe.core.config_check.checkers.base_checker import BaseChecker
 from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
 from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
 from msprobe.core.common.framework_adapter import FmkAdp
+from msprobe.core.common.const import Const


 def collect_weights_data(model):
@@ -143,5 +144,5 @@ class WeightsChecker(BaseChecker):
         bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip)
         cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip)
         df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path)
-        pass_check = False not in df['equal'].values
+        pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR
         return WeightsChecker.target_name_in_zip, pass_check, df
```
msprobe/core/config_check/ckpt_compare/megatron_loader.py

```diff
@@ -138,6 +138,8 @@ def _consolidate_tp_weights(weights: Dict) -> Dict:
 def _parse_num_layers_per_stage(tp_partition):
     match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()]
     layer_idx = [int(i[0]) for i in match if i]
+    if not layer_idx:
+        return 1
     num_layers_per_pipeline_stage = max(layer_idx) + 1

     return num_layers_per_pipeline_stage
```
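The added guard covers partitions whose keys carry no layer index, where `max()` on an empty list would raise `ValueError`; the function now falls back to one layer per stage. A small sketch of the behaviour, using a stand-in for the module's LAYER_IDX_PATTERN:

```python
import re

# Stand-in pattern for illustration; the real LAYER_IDX_PATTERN lives in megatron_loader.py.
LAYER_IDX_PATTERN = r"layers\.(\d+)\."


def parse_num_layers_per_stage(tp_partition):
    match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()]
    layer_idx = [int(i[0]) for i in match if i]
    if not layer_idx:  # without this guard, max([]) raises ValueError
        return 1
    return max(layer_idx) + 1


print(parse_num_layers_per_stage({"decoder.layers.3.mlp.weight": None}))        # 4
print(parse_num_layers_per_stage({"embedding.word_embeddings.weight": None}))   # 1
```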
msprobe/core/config_check/utils/hyperparameter_parser.py

```diff
@@ -96,9 +96,13 @@ class YamlParser(Parser):
                 new_prefix = prefix + Const.SEP + key if prefix else key
                 self.recursive_parse_parameters(value, new_prefix)
         elif isinstance(parameters, list):
-            for
-            self.
-
+            if all(isinstance(x, (int, float, str, bool, list))for x in parameters):
+                self.hyperparameters.update({prefix: parameters})
+            else:
+                for idx, value in enumerate(parameters):
+                    new_prefix = prefix + Const.SEP + str(idx) if prefix else str(idx)
+                    self.recursive_parse_parameters(value, new_prefix)
+        elif isinstance(parameters, (int, float, str, bool)):
             self.hyperparameters.update({prefix: parameters})


```
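With this change, YamlParser keeps a list of plain scalars as one hyperparameter value and only recurses element by element when the list contains nested structures such as dicts. An illustrative sketch of that behaviour, assuming Const.SEP is a dot (the sample parameters are made up):

```python
# Sketch of the new list handling; the real logic is YamlParser.recursive_parse_parameters.
hyperparameters = {}


def recursive_parse(parameters, prefix=""):
    if isinstance(parameters, dict):
        for key, value in parameters.items():
            recursive_parse(value, f"{prefix}.{key}" if prefix else key)
    elif isinstance(parameters, list):
        if all(isinstance(x, (int, float, str, bool, list)) for x in parameters):
            hyperparameters[prefix] = parameters           # keep flat lists whole
        else:
            for idx, value in enumerate(parameters):       # recurse into lists of dicts
                recursive_parse(value, f"{prefix}.{idx}" if prefix else str(idx))
    elif isinstance(parameters, (int, float, str, bool)):
        hyperparameters[prefix] = parameters


recursive_parse({"lr_decay_style": "cosine", "layer_sizes": [1024, 4096], "optimizers": [{"name": "adam"}]})
print(hyperparameters)
# {'lr_decay_style': 'cosine', 'layer_sizes': [1024, 4096], 'optimizers.0.name': 'adam'}
```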
msprobe/core/config_check/utils/utils.py

```diff
@@ -19,6 +19,7 @@ import hashlib

 from msprobe.core.common.framework_adapter import FmkAdp
 from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const


 def merge_keys(dir_0, dir_1):
@@ -105,3 +106,12 @@ def update_dict(ori_dict, new_dict):
             ori_dict[key] = {"description": "duplicate_value", "values": [ori_dict[key], new_dict[key]]}
         else:
             ori_dict[key] = value
+
+
+def process_pass_check(data):
+    if Const.CONFIG_CHECK_ERROR in data:
+        return Const.CONFIG_CHECK_ERROR
+    elif Const.CONFIG_CHECK_WARNING in data:
+        return Const.CONFIG_CHECK_WARNING
+    else:
+        return Const.CONFIG_CHECK_PASS
```
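The new process_pass_check helper collapses a checker's per-row levels into one verdict: any error row fails the check, otherwise a warning row downgrades it, otherwise it passes. A sketch of how the checkers above use it (the constant strings are assumed for illustration; the real values come from msprobe.core.common.const.Const):

```python
import pandas as pd

# Assumed string values for the Const.CONFIG_CHECK_* constants, for illustration only.
CONFIG_CHECK_PASS, CONFIG_CHECK_WARNING, CONFIG_CHECK_ERROR = "pass", "warning", "error"


def process_pass_check(levels):
    # Worst level wins: error > warning > pass.
    if CONFIG_CHECK_ERROR in levels:
        return CONFIG_CHECK_ERROR
    if CONFIG_CHECK_WARNING in levels:
        return CONFIG_CHECK_WARNING
    return CONFIG_CHECK_PASS


df = pd.DataFrame({
    "bench_env": ["NCCL_DEBUG", "HCCL_TIMEOUT"],
    "cmp_env": ["NCCL_DEBUG", "only bench has this env"],
    "level": [CONFIG_CHECK_ERROR, CONFIG_CHECK_WARNING],
})
print(process_pass_check(df["level"].values))  # 'error'
```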
msprobe/core/data_dump/api_registry.py

```diff
@@ -35,7 +35,7 @@ class ApiWrapper:
     def __init__(
             self, api_types: Dict[str, Dict[str, Any]],
             api_list_paths: Union[str, List[str], Tuple[str]],
-
+            blacklist: Union[List[str], Tuple[str]] = None
     ):
         self.api_types = api_types
         if not isinstance(api_list_paths, (list, tuple)):
@@ -44,7 +44,7 @@ class ApiWrapper:
             raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                                "when api_list_paths is a list or tuple.")
         self.api_list_paths = api_list_paths
-        self.
+        self.blacklist = blacklist if blacklist else []
         self.api_names = self._get_api_names()
         self.wrapped_api_functions = dict()

@@ -80,6 +80,26 @@ class ApiWrapper:

         return True, args, kwargs

+    def wrap_api_func(self, api_name, api_func, prefix, hook_build_func, api_template):
+        api_instance = api_template(api_name, api_func, prefix, hook_build_func)
+
+        def api_function(*args, **kwargs):
+            api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
+            enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, api_func, args, kwargs)
+            if not enable_wrap:
+                logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
+                               'It may be fixed by passing the value of "self" '
+                               'as a positional argument instead of a keyword argument. ')
+                return api_func(*args, **kwargs)
+            return api_instance(*args, **kwargs)
+
+        for attr_name in Const.API_ATTR_LIST:
+            if hasattr(api_func, attr_name):
+                attr_value = getattr(api_func, attr_name)
+                setattr(api_function, attr_name, attr_value)
+
+        return api_function
+
     def wrap_api(
             self, api_templates, hook_build_func: Optional[Callable]
     ):
@@ -100,23 +120,17 @@ class ApiWrapper:
                 api_template = api_templates[index]
                 index += 1
                 for api_name in self.api_names.get(framework, {}).get(api_type, []):
-                    ori_api =
+                    ori_api = None
+                    for module in api_modules[0]:
+                        ori_api = ori_api or _get_attr(module, api_name)
                     if callable(ori_api):
-
-
-
-
-
-
-
-                        'It may be fixed by passing the value of "self" '
-                        'as a positional argument instead of a keyword argument. ')
-                        return api_func(*args, **kwargs)
-                        return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
-                        api_function.__name__ = api_name
-                        return api_function
-                        wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
-                                                                    hook_build_func, api_template)
+                        wrapped_functions[api_name] = self.wrap_api_func(
+                            api_name,
+                            ori_api,
+                            name_prefix,
+                            hook_build_func,
+                            api_template
+                        )
                 wrapped_functions_in_framework[api_type] = wrapped_functions
             self.wrapped_api_functions[framework] = wrapped_functions_in_framework
         return self.wrapped_api_functions
```
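The per-API closure that used to be built inline in wrap_api now lives in the wrap_api_func method, which also copies the attributes named in Const.API_ATTR_LIST from the original callable onto the wrapper so the wrapped API stays introspectable. A minimal sketch of that wrapping pattern with stand-in names (not the msprobe hook machinery itself):

```python
# Minimal sketch of the wrapping pattern used by wrap_api_func, with stand-in names;
# Const.API_ATTR_LIST and the real hook template live in msprobe, not here.
API_ATTR_LIST = ("__name__", "__doc__")  # assumed attribute names, for illustration


def wrap_api_func(api_name, api_func, hook):
    def api_function(*args, **kwargs):
        hook(api_name, args, kwargs)      # collect data for this call, then delegate
        return api_func(*args, **kwargs)

    # Copy selected attributes so the wrapper looks like the original callable.
    for attr_name in API_ATTR_LIST:
        if hasattr(api_func, attr_name):
            setattr(api_function, attr_name, getattr(api_func, attr_name))
    return api_function


def rand(size):
    """Stand-in for a framework API such as ops.rand."""
    return [0.0] * size


wrapped = wrap_api_func("ops.rand", rand, lambda name, args, kwargs: print(f"hooked {name}"))
print(wrapped.__name__, wrapped(3))  # prints "hooked ops.rand", then: rand [0.0, 0.0, 0.0]
```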
msprobe/core/data_dump/api_registry.py (continued)

```diff
@@ -132,15 +146,17 @@ class ApiWrapper:
                 api_from_file = api_list.get(key_in_file, [])
                 names = set()
                 for api_name in api_from_file:
-                    if f'{key_in_file}.{api_name}' in self.
+                    if f'{key_in_file}.{api_name}' in self.blacklist:
                         continue
                     target_attr = api_name
-
-
-
-
-
-
+                    for module in api_modules[0]:
+                        if Const.SEP in api_name:
+                            sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
+                            target_module = getattr(module, sub_module_name, None)
+                        else:
+                            target_module = module
+                        if target_module and target_attr in dir(target_module):
+                            names.add(api_name)
                 valid_names[api_type] = names
             api_names[framework] = valid_names

@@ -152,7 +168,7 @@ class ApiRegistry:
     Base class for api registry.
     """

-    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates,
+    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, blacklist=None):
         self.ori_api_attr = dict()
         self.wrapped_api_attr = dict()
         self.inner_used_ori_attr = dict()
@@ -161,13 +177,16 @@ class ApiRegistry:
         self.inner_used_api = inner_used_api
         self.supported_api_list_path = supported_api_list_path
         self.api_templates = api_templates
-        self.
+        self.blacklist = blacklist if blacklist else []
         self.all_api_registered = False

     @staticmethod
-    def store_ori_attr(
+    def store_ori_attr(ori_api_groups, api_list, api_ori_attr):
         for api in api_list:
-
+            ori_api = None
+            for ori_api_group in ori_api_groups:
+                ori_api = ori_api or _get_attr(ori_api_group, api)
+            api_ori_attr[api] = ori_api

     @staticmethod
     def set_api_attr(api_group, attr_dict):
@@ -217,7 +236,7 @@ class ApiRegistry:
         self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))

     def initialize_hook(self, hook_build_func):
-        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.
+        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.blacklist)
         wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)

         for framework, api_types in self.api_types.items():
```
msprobe/core/data_dump/data_collector.py

```diff
@@ -23,6 +23,7 @@ from msprobe.core.data_dump.json_writer import DataWriter
 from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
+from msprobe.core.common.megatron_utils import MegatronStepInfo, get_micro_step, is_megatron


 def build_data_collector(config):
@@ -41,6 +42,7 @@ class DataCollector:
         self.module_count = {}
         self.scope = ScopeFactory(self.config).build_scope()
         self.backward_module_names = {}
+        self.params_grad_record = {}
         self.optimizer_status = ""
         self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
         atexit.register(self.write_json_at_exit)
```
```diff
@@ -118,12 +120,16 @@ class DataCollector:
             self.set_is_recomputable(data_info, is_recompute)
             if self.config.level == Const.LEVEL_L2:
                 return
+            self.call_stack_collect(name)
             self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-        except Exception:
+        except Exception as e:
+            # Use the exception class name as the "type" for de-duplication
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
```
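This and the following collect hunks now pass the exception class name to write_error_log alongside the formatted traceback, so repeated failures can be de-duplicated by type instead of flooding the error log. A sketch of the calling pattern; the de-duplication policy inside the writer is an assumption, only the error_type keyword comes from the diff:

```python
import traceback


class DataWriterSketch:
    """Illustrative stand-in for msprobe's DataWriter error logging."""

    def __init__(self):
        self.seen_error_types = set()

    def write_error_log(self, message, error_type=None):
        # Assumed policy: log each exception type once per run.
        if error_type in self.seen_error_types:
            return
        self.seen_error_types.add(error_type)
        print(message)


writer = DataWriterSketch()
for value in ("1", "x", "y"):
    try:
        int(value)
    except Exception as e:
        writer.write_error_log(
            f"[ERROR] collect failed\n{traceback.format_exc()}",
            error_type=type(e).__name__,
        )
# Only the first ValueError traceback is written; the second one is suppressed.
```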
msprobe/core/data_dump/data_collector.py (continued)

```diff
@@ -139,13 +145,15 @@ class DataCollector:
             self.set_is_recomputable(data_info, is_recompute)
             if self.config.level == Const.LEVEL_L2:
                 return
-            self.call_stack_collect(name)
             self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-        except Exception:
+        except Exception as e:
+            # Use the exception class name as the "type" for de-duplication
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
@@ -154,10 +162,13 @@ class DataCollector:
                 return
             self.data_processor.analyze_forward(name, module, module_input_output)

-        except Exception:
+        except Exception as e:
+            # Use the exception class name as the "type" for de-duplication
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -173,10 +184,12 @@ class DataCollector:
             self.call_stack_collect(name)
             self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-        except Exception:
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
@@ -185,10 +198,12 @@ class DataCollector:
                 return
             self.data_processor.analyze_backward(name, module, module_input_output)

-        except Exception:
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -206,10 +221,12 @@ class DataCollector:
                 self.backward_module_names[module_name] = True
             self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-        except Exception:
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -223,10 +240,12 @@ class DataCollector:
             self.set_is_recomputable(data_info, is_recompute)
             self.handle_data(name, data_info)

-        except Exception:
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
```
```diff
@@ -240,25 +259,32 @@ class DataCollector:
             self.set_is_recomputable(data_info, is_recompute)
             self.handle_data(name, data_info)

-        except Exception:
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
+                f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}",
+                error_type=error_type
             )

     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
             if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
                 if self.optimizer_status_first_start[self.optimizer_status]:
-                    self.data_writer.update_construct(
+                    self.data_writer.update_construct(
+                        {self.optimizer_status: None if not is_megatron() else [None, get_micro_step()]})
                     self.optimizer_status_first_start[self.optimizer_status] = False
-                self.data_writer.update_construct(
+                self.data_writer.update_construct(
+                    {name: self.optimizer_status if not is_megatron() else [self.optimizer_status, get_micro_step()]})
             else:
                 if self.config.level == Const.LEVEL_MIX and \
                         not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
                     self.data_writer.update_construct(
                         {name: self.module_processor.api_parent_node.get(threading.get_ident())}
                     )
+                if MegatronStepInfo.is_megatron:
+                    micro_step_number = max(MegatronStepInfo.forward_micro_step, MegatronStepInfo.backward_micro_step)
+                    self.data_writer.update_construct({Const.MEGATRON_MICRO_STEP_NUMBER: micro_step_number})

             self.data_writer.update_construct(self.module_processor.module_node)

```
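When a Megatron run is detected, construct entries record a [parent, micro_step] pair instead of a bare parent, and the number of observed micro steps is written under Const.MEGATRON_MICRO_STEP_NUMBER, so downstream tooling can split one training step into its micro-batches. A hypothetical construct.json fragment illustrating the resulting shape (key strings are assumptions inferred from this hunk, not taken from the msprobe docs):

```python
import json

# Hypothetical construct.json fragment for a Megatron run with two micro steps;
# "megatron_micro_step_number" is an assumed string value for Const.MEGATRON_MICRO_STEP_NUMBER.
construct = {
    "Module.decoder.forward.0": [None, 0],   # [parent node, micro step]
    "Module.decoder.forward.1": [None, 1],
    "optimizer": [None, 1],
    "megatron_micro_step_number": 2,
}
print(json.dumps(construct, indent=2))
```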
msprobe/core/data_dump/data_collector.py (continued)

```diff
@@ -282,20 +308,36 @@ class DataCollector:
         self.data_processor.update_iter(current_iter)

     def params_data_collect(self, name, param_name, pid, data):
+        grad_name = name + Const.SEP + Const.PARAMS_GRAD
+        self.update_api_or_module_name(grad_name)
+        if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+            if self.data_writer.cache_data.get("data"):
+                self.data_writer.cache_data.get("data").pop(grad_name, None)
+            self.params_grad_record[grad_name] = False
+            return
+        data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+        self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+        self.params_grad_record[grad_name] = False
+
+    def params_data_collect_in_bw_hook(self, params_dict, name):
         try:
-
-            self.update_api_or_module_name(grad_name)
-            if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
-                if self.data_writer.cache_data.get("data"):
-                    self.data_writer.cache_data.get("data").pop(grad_name, None)
+            if not params_dict:
                 return
-
-
-
+            ori_name = name.rsplit(Const.SEP, 2)[0]
+            for param_name, param in params_dict.items():
+                grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                self.update_api_or_module_name(grad_name)
+                if self.params_grad_record.get(grad_name, False):
+                    grad = param.grad if hasattr(param, "grad") else None
+                    data_info = self.data_processor.analyze_params(grad_name, param_name, grad)
+                    self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+        except Exception as e:
+            error_type = type(e).__name__
             tb = traceback.format_exc()
             self.data_writer.write_error_log(
-                f"[ERROR]
-                f"name={name},
+                f"[ERROR] params_data_collect_in_bw_hook failed: "
+                f"name={name}",
+                error_type=error_type
             )

     def debug_data_collect_forward(self, variable, name_with_count):
```
msprobe/core/data_dump/data_processor/base.py

```diff
@@ -94,6 +94,8 @@ class BaseDataProcessor:
     def __init__(self, config, data_writer):
        self.data_writer = data_writer
        self.config = config
+       if self.data_writer is not None:
+           self.data_writer.config = config
        self.api_info_struct = {}
        self.stack_info_struct = {}
        self.current_api_or_module_name = None
```