mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -23,48 +23,20 @@ from tqdm import tqdm
|
|
|
23
23
|
from msprobe.core.common.log import logger
|
|
24
24
|
from msprobe.core.common.utils import CompareException
|
|
25
25
|
from msprobe.core.common.const import CompareConst
|
|
26
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
27
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
|
|
28
|
+
from msprobe.core.compare.config import ModeConfig
|
|
26
29
|
|
|
27
30
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
results = []
|
|
39
|
-
pool = multiprocessing.Pool(process_num)
|
|
40
|
-
|
|
41
|
-
def err_call(args):
|
|
42
|
-
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
43
|
-
try:
|
|
44
|
-
pool.terminate()
|
|
45
|
-
except OSError as e:
|
|
46
|
-
logger.error("pool terminate failed")
|
|
47
|
-
|
|
48
|
-
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
49
|
-
|
|
50
|
-
def update_progress(size, progress_lock, extra_param=None):
|
|
51
|
-
with progress_lock:
|
|
52
|
-
progress_bar.update(size)
|
|
53
|
-
|
|
54
|
-
for process_idx, df_chunk in enumerate(df_chunks):
|
|
55
|
-
idx = df_chunk_size * process_idx
|
|
56
|
-
chunk_size = len(df_chunk)
|
|
57
|
-
result = pool.apply_async(func,
|
|
58
|
-
args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
|
|
59
|
-
error_callback=err_call,
|
|
60
|
-
callback=partial(update_progress, chunk_size, lock)
|
|
61
|
-
)
|
|
62
|
-
results.append(result)
|
|
63
|
-
|
|
64
|
-
final_results = [r.get() for r in results]
|
|
65
|
-
pool.close()
|
|
66
|
-
pool.join()
|
|
67
|
-
return pd.concat(final_results, ignore_index=True)
|
|
31
|
+
@dataclass
class ComparisonResult:
    """Per-row comparison metrics gathered for one chunk of the result DataFrame.

    Each field is a list holding one entry per compared row, in row order.
    """
    cos_result: list  # cosine similarity per row
    euc_dist_result: list  # Euclidean distance per row
    max_err_result: list  # maximum absolute error per row
    max_relative_err_result: list  # maximum relative error per row
    one_thousand_err_ratio_result: list  # one-thousandth error ratio per row
    five_thousand_err_ratio_result: list  # five-thousandths error ratio per row
    err_msgs: list  # error message per row
|
|
68
40
|
|
|
69
41
|
|
|
70
42
|
def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
@@ -81,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
81
53
|
def err_call(args):
|
|
82
54
|
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
83
55
|
try:
|
|
84
|
-
pool.
|
|
56
|
+
pool.close()
|
|
85
57
|
except OSError as e:
|
|
86
|
-
logger.error(
|
|
58
|
+
logger.error(f'pool terminate failed: {str(e)}')
|
|
87
59
|
|
|
88
60
|
for df_chunk in df_chunks:
|
|
89
61
|
result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
|
|
@@ -94,74 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
94
66
|
return pd.concat(final_results, ignore_index=True)
|
|
95
67
|
|
|
96
68
|
|
|
97
|
-
def read_dump_data(result_df):
    """Map each NPU op name (column 0) to its dump tensor pair (last column).

    Args:
        result_df: compare-result DataFrame.

    Returns:
        dict of NPU op name -> dump tensor pair; for duplicated names the last row wins.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR when the
            DataFrame columns cannot be accessed.
    """
    try:
        npu_dump_name_list = result_df.iloc[0:, 0].tolist()
        dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    # Both lists are columns of the same DataFrame, so they are the same length.
    return dict(zip(npu_dump_name_list, dump_tensor_pair_list))
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@dataclass
class ComparisonResult:
    """Container for the comparison metrics of a chunk of rows.

    Every attribute is a list whose i-th entry belongs to the i-th compared row.
    """
    cos_result: list  # cosine similarity
    euc_dist_result: list  # Euclidean distance
    max_err_result: list  # maximum absolute error
    max_relative_err_result: list  # maximum relative error
    one_thousand_err_ratio_result: list  # one-thousandth error ratio
    five_thousand_err_ratio_result: list  # five-thousandths error ratio
    err_msgs: list  # error messages
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Write one chunk's comparison metrics into `result_df` while holding `lock`.

    Args:
        offset: added to each result position to obtain the row label used with `.loc`
        result: ComparisonResult holding per-row metric lists
        result_df: DataFrame that is filled in place
        lock: lock held for the duration of the writes

    Returns:
        the same `result_df`, now populated

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR when a
            metric list is shorter than `result.cos_result`.
    """
    lock.acquire()
    try:
        for pos, cos_value in enumerate(result.cos_result):
            # NOTE(review): rows are addressed by label, which assumes result_df's index
            # equals its positional index (default RangeIndex) — confirm against callers.
            row = pos + offset
            result_df.loc[row, CompareConst.COSINE] = cos_value
            result_df.loc[row, CompareConst.EUC_DIST] = result.euc_dist_result[pos]
            abs_err = result.max_err_result[pos]
            result_df.loc[row, CompareConst.MAX_ABS_ERR] = abs_err
            result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[pos]
            result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                result.one_thousand_err_ratio_result[pos]
            result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                result.five_thousand_err_ratio_result[pos]
            result_df.loc[row, CompareConst.ACCURACY] = check_accuracy(cos_value, abs_err)
            result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[pos]
        return result_df
    except ValueError as err:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from err
    except IndexError as err:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from err
    finally:
        lock.release()
|
|
163
|
-
|
|
164
|
-
|
|
165
69
|
def check_accuracy(cos, max_abs_err):
|
|
166
70
|
if cos == CompareConst.SHAPE_UNMATCH:
|
|
167
71
|
return CompareConst.ACCURACY_CHECK_UNMATCH
|
|
@@ -179,3 +83,212 @@ def check_accuracy(cos, max_abs_err):
|
|
|
179
83
|
if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
|
|
180
84
|
return CompareConst.ACCURACY_CHECK_NO
|
|
181
85
|
return CompareConst.ACCURACY_CHECK_YES
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class CompareRealData:
    """Compare real dump data (npy/pt files) for every row of a compare-result DataFrame.

    Rows are read as: column 0 = NPU op name, column 1 = bench op name, last column =
    (npu_data_name, bench_data_name) pair. Rows are split into chunks that are compared
    in a multiprocessing pool; each worker fills the metric columns of its chunk.
    """

    def __init__(self, file_reader, mode_config: ModeConfig, cross_frame):
        """
        Args:
            file_reader: callable invoked as
                file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name, cross_frame)
                and expected to return the (npu_value, bench_value) pair.
            mode_config: compare mode configuration; this class reads `fuzzy_match` from it.
            cross_frame: forwarded unchanged to `file_reader`.
        """
        self.file_reader = file_reader
        self.mode_config = mode_config
        self.cross_frame = cross_frame

    @staticmethod
    def read_dump_data(result_df):
        """Map each NPU op name (column 0) to its dump tensor pair (last column).

        Args:
            result_df: compare-result DataFrame.

        Returns:
            dict of NPU op name -> dump tensor pair; for duplicated names the last row wins.

        Raises:
            CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR when
                the DataFrame columns cannot be accessed.
        """
        try:
            npu_dump_name_list = result_df.iloc[0:, 0].tolist()
            dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
        # Both lists are columns of the same DataFrame, so they are the same length.
        return dict(zip(npu_dump_name_list, dump_tensor_pair_list))

    @staticmethod
    def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
        """
        Save comparison results into the result DataFrame while holding `lock`.

        Args:
            offset: added to each result position to obtain the row label used with `.loc`
            result: data struct of ComparisonResult
            result_df: result DataFrame (chunk), filled in place
            lock: lock held while writing

        Returns:
            comparison results in DataFrame

        Raises:
            CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR when a
                metric list is shorter than `result.cos_result`.
        """
        lock.acquire()
        try:
            for i, cos_item in enumerate(result.cos_result):
                # NOTE(review): rows are addressed by label (`.loc`), which assumes the chunk's
                # index labels equal offset + position (default RangeIndex) — confirm.
                process_index = i + offset
                result_df.loc[process_index, CompareConst.COSINE] = cos_item
                result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
                result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
                result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
                result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
                    result.one_thousand_err_ratio_result)[i]
                result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
                    result.five_thousand_err_ratio_result)[i]
                result_df.loc[process_index, CompareConst.ACCURACY] = (
                    check_accuracy(cos_item, result.max_err_result[i]))
                result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
            return result_df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
        finally:
            lock.release()

    def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
        """
        Compare the real data of a single NPU/bench op pair.

        :param npu_op_name: NPU_Name in the result table, e.g. MintFunctional.conv2d.0.forward.input.3.0
        :param bench_op_name: Bench_Name in the result table, e.g. Functional.conv2d.0.forward.input.3.0
        :param op_name_mapping_dict: mapping from op name to its (npu, bench) npy/pt data file names
        :param input_param: parameters such as npu_json_path/bench_json_path/stack_json_path and the
            NPU/bench dump data directories
        :return: result_list in the order [cosine, euclidean distance, max abs error,
            max relative error, one-thousandth error ratio, five-thousandths error ratio, error message]

        Looks up the data files of the op pair through the mapping, reads them with `file_reader`
        and computes the metrics above together with an error message.
        """
        error_file, relative_err, error_flag = None, None, False

        data_name_pair = op_name_mapping_dict.get(npu_op_name)
        npu_data_name = data_name_pair[0]
        bench_data_name = data_name_pair[1]

        if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG:  # no real NPU data
            n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
        elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG:  # no real bench data
            n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
            error_file = 'no_bench_data'
        elif str(bench_data_name) == CompareConst.N_A:  # bench side was not matched
            n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
            error_file = None
        else:
            npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR)
            bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR)
            try:
                n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name,
                                                    self.cross_frame)
            except IOError as error:
                error_file = error.filename
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True
            except (FileCheckException, CompareException):
                error_file = data_name_pair
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True

        # Derive the error flag and the error message from n_value/b_value together.
        n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
                                                                       error_flag=error_flag, error_file=error_file)

        result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)

        if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
            err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
        result_list.append(err_msg)
        return result_list

    def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
        """Compare every row of a DataFrame chunk and write the metrics into it.

        Args:
            idx: label offset of the chunk's first row in the full result DataFrame
            dump_path_dict: op name -> (npu, bench) data file names, see `read_dump_data`
            result_df: DataFrame chunk to compare (column 0 NPU name, column 1 bench name)
            lock: lock passed on to `_save_cmp_result`
            input_param: compare parameters; `is_print_compare_log` enables per-row logging

        Returns:
            the chunk with its metric columns filled in.
        """
        cos_result = []
        euc_dist_result = []
        max_err_result = []
        max_relative_err_result = []
        one_thousand_err_ratio_result = []
        five_thousand_err_ratio_result = []
        err_mess = []

        is_print_compare_log = input_param.get("is_print_compare_log")

        for npu_op_name, bench_op_name in zip(result_df.iloc[:, 0], result_df.iloc[:, 1]):
            if is_print_compare_log:
                logger.info("start compare: {}".format(npu_op_name))

            cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
                = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)

            if is_print_compare_log:
                # Implicit string concatenation: a backslash continuation inside the literal
                # would embed the next line's indentation into the log message.
                logger.info(
                    "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, "
                    "one_thousand_err_ratio {}, five_thousand_err_ratio {}".format(
                        npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                        one_thousand_err_ratio, five_thousand_err_ratio))
            cos_result.append(cos_sim)
            euc_dist_result.append(euc_dist)
            max_err_result.append(max_abs_err)
            max_relative_err_result.append(max_relative_err)
            one_thousand_err_ratio_result.append(one_thousand_err_ratio)
            five_thousand_err_ratio_result.append(five_thousand_err_ratio)
            err_mess.append(err_msg)

        cr = ComparisonResult(
            cos_result=cos_result,
            euc_dist_result=euc_dist_result,
            max_err_result=max_err_result,
            max_relative_err_result=max_relative_err_result,
            one_thousand_err_ratio_result=one_thousand_err_ratio_result,
            five_thousand_err_ratio_result=five_thousand_err_ratio_result,
            err_msgs=err_mess
        )

        return self._save_cmp_result(idx, cr, result_df, lock)

    def do_multi_process(self, input_param, result_df):
        """Run `compare_ops` over `result_df` in a process pool and return the merged result.

        Raises:
            CompareException: INVALID_DATA_ERROR when a ValueError escapes the comparison.
        """
        try:
            # The Manager owns a server process; the context manager guarantees it is shut
            # down once the workers are done with the shared lock.
            with multiprocessing.Manager() as manager:
                return self._handle_multi_process(self.compare_ops, input_param, result_df,
                                                  manager.RLock())
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e

    def _handle_multi_process(self, func, input_param, result_df, lock):
        """Split `result_df` into chunks, run `func` on each in a pool and concatenate results.

        `func` is called as func(offset, op_name_mapping_dict, df_chunk, lock, input_param)
        and must return the processed chunk as a DataFrame.
        """
        process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
        op_name_mapping_dict = self.read_dump_data(result_df)

        df_chunk_size = len(result_df) // process_num
        if df_chunk_size > 0:
            df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
        else:
            df_chunks = [result_df]

        progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)

        def err_call(args):
            # Only log: this runs on the pool's result-handler thread, and closing the pool
            # here could race with the submission loop below ("Pool not running"). The
            # worker's exception is re-raised by AsyncResult.get() in the main thread.
            logger.error('multiprocess compare failed! Reason: {}'.format(args))

        def update_progress(size, progress_lock, extra_param=None):
            # extra_param receives the worker's return value passed by apply_async; unused.
            with progress_lock:
                progress_bar.update(size)

        pool = multiprocessing.Pool(process_num)
        try:
            results = []
            for process_idx, df_chunk in enumerate(df_chunks):
                idx = df_chunk_size * process_idx
                chunk_size = len(df_chunk)
                result = pool.apply_async(func,
                                          args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
                                          error_callback=err_call,
                                          callback=partial(update_progress, chunk_size, lock)
                                          )
                results.append(result)

            final_results = [r.get() for r in results]
        except BaseException:
            # A chunk failed (or we were interrupted): stop outstanding workers rather than
            # waiting for them, then propagate the original error.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
            progress_bar.close()
        return pd.concat(final_results, ignore_index=True)
|
|
@@ -59,7 +59,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
|
59
59
|
if error_file == "no_bench_data":
|
|
60
60
|
err_msg = "Bench does not have data file."
|
|
61
61
|
elif error_file:
|
|
62
|
-
err_msg = f"Dump file: {error_file} not found."
|
|
62
|
+
err_msg = f"Dump file: {error_file} not found or read failed."
|
|
63
63
|
else:
|
|
64
64
|
err_msg = CompareConst.NO_BENCH
|
|
65
65
|
error_flag = True
|
|
@@ -290,10 +290,8 @@ class CompareOps:
|
|
|
290
290
|
|
|
291
291
|
|
|
292
292
|
def error_value_process(n_value):
|
|
293
|
-
if n_value
|
|
293
|
+
if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE]:
|
|
294
294
|
return CompareConst.UNSUPPORTED, ""
|
|
295
|
-
if n_value == CompareConst.NONE:
|
|
296
|
-
return 0, ""
|
|
297
295
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
298
296
|
return CompareConst.SHAPE_UNMATCH, ""
|
|
299
297
|
if n_value == CompareConst.NAN:
|