mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,51 +15,28 @@
|
|
|
15
15
|
|
|
16
16
|
import multiprocessing
|
|
17
17
|
from dataclasses import dataclass
|
|
18
|
+
from functools import partial
|
|
19
|
+
|
|
18
20
|
import pandas as pd
|
|
19
21
|
from tqdm import tqdm
|
|
22
|
+
|
|
20
23
|
from msprobe.core.common.log import logger
|
|
21
24
|
from msprobe.core.common.utils import CompareException
|
|
22
25
|
from msprobe.core.common.const import CompareConst
|
|
26
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
27
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
|
|
28
|
+
from msprobe.core.compare.config import ModeConfig
|
|
23
29
|
|
|
24
30
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
results = []
|
|
36
|
-
pool = multiprocessing.Pool(process_num)
|
|
37
|
-
|
|
38
|
-
def err_call(args):
|
|
39
|
-
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
40
|
-
try:
|
|
41
|
-
pool.terminate()
|
|
42
|
-
except OSError as e:
|
|
43
|
-
logger.error("pool terminate failed")
|
|
44
|
-
|
|
45
|
-
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
46
|
-
|
|
47
|
-
def update_progress(size, progress_lock):
|
|
48
|
-
with progress_lock:
|
|
49
|
-
progress_bar.update(size)
|
|
50
|
-
|
|
51
|
-
for process_idx, df_chunk in enumerate(df_chunks):
|
|
52
|
-
idx = df_chunk_size * process_idx
|
|
53
|
-
chunk_size = len(df_chunk)
|
|
54
|
-
result = pool.apply_async(func,
|
|
55
|
-
args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
|
|
56
|
-
error_callback=err_call,
|
|
57
|
-
callback=update_progress(chunk_size, lock))
|
|
58
|
-
results.append(result)
|
|
59
|
-
final_results = [r.get() for r in results]
|
|
60
|
-
pool.close()
|
|
61
|
-
pool.join()
|
|
62
|
-
return pd.concat(final_results, ignore_index=True)
|
|
31
|
+
@dataclass
|
|
32
|
+
class ComparisonResult:
|
|
33
|
+
cos_result: list
|
|
34
|
+
euc_dist_result: list
|
|
35
|
+
max_err_result: list
|
|
36
|
+
max_relative_err_result: list
|
|
37
|
+
one_thousand_err_ratio_result: list
|
|
38
|
+
five_thousand_err_ratio_result: list
|
|
39
|
+
err_msgs: list
|
|
63
40
|
|
|
64
41
|
|
|
65
42
|
def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
@@ -76,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
76
53
|
def err_call(args):
|
|
77
54
|
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
78
55
|
try:
|
|
79
|
-
pool.
|
|
56
|
+
pool.close()
|
|
80
57
|
except OSError as e:
|
|
81
|
-
logger.error(
|
|
58
|
+
logger.error(f'pool terminate failed: {str(e)}')
|
|
82
59
|
|
|
83
60
|
for df_chunk in df_chunks:
|
|
84
61
|
result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
|
|
@@ -89,72 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
89
66
|
return pd.concat(final_results, ignore_index=True)
|
|
90
67
|
|
|
91
68
|
|
|
92
|
-
def read_dump_data(result_df):
|
|
93
|
-
try:
|
|
94
|
-
npu_dump_name_list = result_df.iloc[0:, 0].tolist()
|
|
95
|
-
npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
|
|
96
|
-
op_name_mapping_dict = {}
|
|
97
|
-
for index, _ in enumerate(npu_dump_name_list):
|
|
98
|
-
npu_dump_name = npu_dump_name_list[index]
|
|
99
|
-
npu_dump_tensor = npu_dump_tensor_list[index]
|
|
100
|
-
op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
|
|
101
|
-
return op_name_mapping_dict
|
|
102
|
-
except ValueError as e:
|
|
103
|
-
logger.error('result dataframe is not found.')
|
|
104
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
105
|
-
except IndexError as e:
|
|
106
|
-
logger.error('result dataframe elements can not be access.')
|
|
107
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
@dataclass
|
|
111
|
-
class ComparisonResult:
|
|
112
|
-
cos_result: list
|
|
113
|
-
max_err_result: list
|
|
114
|
-
max_relative_err_result: list
|
|
115
|
-
err_msgs: list
|
|
116
|
-
one_thousand_err_ratio_result: list
|
|
117
|
-
five_thousand_err_ratio_result: list
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
121
|
-
"""
|
|
122
|
-
Save comparison results into the result DataFrame with thread safety.
|
|
123
|
-
Args:
|
|
124
|
-
offset: offset for index
|
|
125
|
-
result: data struct of ComparisonResult
|
|
126
|
-
result_df: result of DataFrame
|
|
127
|
-
lock: thread lock
|
|
128
|
-
|
|
129
|
-
Returns:
|
|
130
|
-
comparison results in DataFrame
|
|
131
|
-
"""
|
|
132
|
-
|
|
133
|
-
lock.acquire()
|
|
134
|
-
try:
|
|
135
|
-
for i, _ in enumerate(result.cos_result):
|
|
136
|
-
process_index = i + offset
|
|
137
|
-
result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
|
|
138
|
-
result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
|
|
139
|
-
result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
|
|
140
|
-
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
141
|
-
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
142
|
-
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
143
|
-
result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
|
|
144
|
-
result.one_thousand_err_ratio_result)[i]
|
|
145
|
-
result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
|
|
146
|
-
result.five_thousand_err_ratio_result)[i]
|
|
147
|
-
return result_df
|
|
148
|
-
except ValueError as e:
|
|
149
|
-
logger.error('result dataframe is not found.')
|
|
150
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
151
|
-
except IndexError as e:
|
|
152
|
-
logger.error('result dataframe elements can not be access.')
|
|
153
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
154
|
-
finally:
|
|
155
|
-
lock.release()
|
|
156
|
-
|
|
157
|
-
|
|
158
69
|
def check_accuracy(cos, max_abs_err):
|
|
159
70
|
if cos == CompareConst.SHAPE_UNMATCH:
|
|
160
71
|
return CompareConst.ACCURACY_CHECK_UNMATCH
|
|
@@ -172,3 +83,212 @@ def check_accuracy(cos, max_abs_err):
|
|
|
172
83
|
if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
|
|
173
84
|
return CompareConst.ACCURACY_CHECK_NO
|
|
174
85
|
return CompareConst.ACCURACY_CHECK_YES
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class CompareRealData:
|
|
89
|
+
def __init__(self, file_reader, mode_config: ModeConfig, cross_frame):
|
|
90
|
+
self.file_reader = file_reader
|
|
91
|
+
self.mode_config = mode_config
|
|
92
|
+
self.cross_frame = cross_frame
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def read_dump_data(result_df):
|
|
96
|
+
try:
|
|
97
|
+
npu_dump_name_list = result_df.iloc[0:, 0].tolist()
|
|
98
|
+
dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
|
|
99
|
+
op_name_mapping_dict = {}
|
|
100
|
+
for index, npu_dump_name in enumerate(npu_dump_name_list):
|
|
101
|
+
dump_tensor_pair = dump_tensor_pair_list[index]
|
|
102
|
+
op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
|
|
103
|
+
return op_name_mapping_dict
|
|
104
|
+
except ValueError as e:
|
|
105
|
+
logger.error('result dataframe is not found.')
|
|
106
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
107
|
+
except IndexError as e:
|
|
108
|
+
logger.error('result dataframe elements can not be access.')
|
|
109
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
110
|
+
|
|
111
|
+
@staticmethod
|
|
112
|
+
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
113
|
+
"""
|
|
114
|
+
Save comparison results into the result DataFrame with thread safety.
|
|
115
|
+
Args:
|
|
116
|
+
offset: offset for index
|
|
117
|
+
result: data struct of ComparisonResult
|
|
118
|
+
result_df: result of DataFrame
|
|
119
|
+
lock: thread lock
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
comparison results in DataFrame
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
lock.acquire()
|
|
126
|
+
try:
|
|
127
|
+
for i, cos_item in enumerate(result.cos_result):
|
|
128
|
+
process_index = i + offset
|
|
129
|
+
result_df.loc[process_index, CompareConst.COSINE] = cos_item
|
|
130
|
+
result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
|
|
131
|
+
result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
|
|
132
|
+
result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
|
|
133
|
+
result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
|
|
134
|
+
result.one_thousand_err_ratio_result)[i]
|
|
135
|
+
result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
|
|
136
|
+
result.five_thousand_err_ratio_result)[i]
|
|
137
|
+
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
138
|
+
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
139
|
+
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
140
|
+
return result_df
|
|
141
|
+
except ValueError as e:
|
|
142
|
+
logger.error('result dataframe is not found.')
|
|
143
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
144
|
+
except IndexError as e:
|
|
145
|
+
logger.error('result dataframe elements can not be access.')
|
|
146
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
147
|
+
finally:
|
|
148
|
+
lock.release()
|
|
149
|
+
|
|
150
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
151
|
+
"""
|
|
152
|
+
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
153
|
+
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
154
|
+
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
155
|
+
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
156
|
+
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
157
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
|
|
158
|
+
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
159
|
+
"""
|
|
160
|
+
error_file, relative_err, error_flag = None, None, False
|
|
161
|
+
|
|
162
|
+
data_name_pair = op_name_mapping_dict.get(npu_op_name)
|
|
163
|
+
npu_data_name = data_name_pair[0]
|
|
164
|
+
bench_data_name = data_name_pair[1]
|
|
165
|
+
|
|
166
|
+
if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有npu真实数据
|
|
167
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
168
|
+
elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有bench真实数据
|
|
169
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
170
|
+
error_file = 'no_bench_data'
|
|
171
|
+
elif str(bench_data_name) == CompareConst.N_A: # bench没匹配
|
|
172
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
173
|
+
error_file = None
|
|
174
|
+
else:
|
|
175
|
+
npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR)
|
|
176
|
+
bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR)
|
|
177
|
+
try:
|
|
178
|
+
n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name,
|
|
179
|
+
self.cross_frame)
|
|
180
|
+
except IOError as error:
|
|
181
|
+
error_file = error.filename
|
|
182
|
+
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
183
|
+
error_flag = True
|
|
184
|
+
except (FileCheckException, CompareException):
|
|
185
|
+
error_file = data_name_pair
|
|
186
|
+
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
187
|
+
error_flag = True
|
|
188
|
+
|
|
189
|
+
# 通过n_value, b_value同时得到错误标志和错误信息
|
|
190
|
+
n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
|
|
191
|
+
error_flag=error_flag, error_file=error_file)
|
|
192
|
+
|
|
193
|
+
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
|
|
194
|
+
|
|
195
|
+
if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
196
|
+
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
197
|
+
result_list.append(err_msg)
|
|
198
|
+
return result_list
|
|
199
|
+
|
|
200
|
+
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
201
|
+
cos_result = []
|
|
202
|
+
euc_dist_result = []
|
|
203
|
+
max_err_result = []
|
|
204
|
+
max_relative_err_result = []
|
|
205
|
+
one_thousand_err_ratio_result = []
|
|
206
|
+
five_thousand_err_ratio_result = []
|
|
207
|
+
err_mess = []
|
|
208
|
+
|
|
209
|
+
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
210
|
+
|
|
211
|
+
for i in range(len(result_df)):
|
|
212
|
+
npu_op_name = result_df.iloc[i, 0]
|
|
213
|
+
bench_op_name = result_df.iloc[i, 1]
|
|
214
|
+
if is_print_compare_log:
|
|
215
|
+
logger.info("start compare: {}".format(npu_op_name))
|
|
216
|
+
|
|
217
|
+
cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
|
|
218
|
+
= self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
219
|
+
|
|
220
|
+
if is_print_compare_log:
|
|
221
|
+
logger.info(
|
|
222
|
+
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
|
|
223
|
+
one_thousand_err_ratio {}, "
|
|
224
|
+
"five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
|
|
225
|
+
err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
|
|
226
|
+
cos_result.append(cos_sim)
|
|
227
|
+
euc_dist_result.append(euc_dist)
|
|
228
|
+
max_err_result.append(max_abs_err)
|
|
229
|
+
max_relative_err_result.append(max_relative_err)
|
|
230
|
+
one_thousand_err_ratio_result.append(one_thousand_err_ratio)
|
|
231
|
+
five_thousand_err_ratio_result.append(five_thousand_err_ratio)
|
|
232
|
+
err_mess.append(err_msg)
|
|
233
|
+
|
|
234
|
+
cr = ComparisonResult(
|
|
235
|
+
cos_result=cos_result,
|
|
236
|
+
euc_dist_result=euc_dist_result,
|
|
237
|
+
max_err_result=max_err_result,
|
|
238
|
+
max_relative_err_result=max_relative_err_result,
|
|
239
|
+
one_thousand_err_ratio_result=one_thousand_err_ratio_result,
|
|
240
|
+
five_thousand_err_ratio_result=five_thousand_err_ratio_result,
|
|
241
|
+
err_msgs=err_mess
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
return self._save_cmp_result(idx, cr, result_df, lock)
|
|
245
|
+
|
|
246
|
+
def do_multi_process(self, input_param, result_df):
|
|
247
|
+
try:
|
|
248
|
+
result_df = self._handle_multi_process(self.compare_ops, input_param, result_df,
|
|
249
|
+
multiprocessing.Manager().RLock())
|
|
250
|
+
return result_df
|
|
251
|
+
except ValueError as e:
|
|
252
|
+
logger.error('result dataframe is not found.')
|
|
253
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
254
|
+
|
|
255
|
+
def _handle_multi_process(self, func, input_param, result_df, lock):
|
|
256
|
+
process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
|
|
257
|
+
op_name_mapping_dict = self.read_dump_data(result_df)
|
|
258
|
+
|
|
259
|
+
df_chunk_size = len(result_df) // process_num
|
|
260
|
+
if df_chunk_size > 0:
|
|
261
|
+
df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
|
|
262
|
+
else:
|
|
263
|
+
df_chunks = [result_df]
|
|
264
|
+
|
|
265
|
+
results = []
|
|
266
|
+
pool = multiprocessing.Pool(process_num)
|
|
267
|
+
|
|
268
|
+
def err_call(args):
|
|
269
|
+
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
270
|
+
try:
|
|
271
|
+
pool.close()
|
|
272
|
+
except OSError:
|
|
273
|
+
logger.error("pool terminate failed")
|
|
274
|
+
|
|
275
|
+
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
276
|
+
|
|
277
|
+
def update_progress(size, progress_lock, extra_param=None):
|
|
278
|
+
with progress_lock:
|
|
279
|
+
progress_bar.update(size)
|
|
280
|
+
|
|
281
|
+
for process_idx, df_chunk in enumerate(df_chunks):
|
|
282
|
+
idx = df_chunk_size * process_idx
|
|
283
|
+
chunk_size = len(df_chunk)
|
|
284
|
+
result = pool.apply_async(func,
|
|
285
|
+
args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
|
|
286
|
+
error_callback=err_call,
|
|
287
|
+
callback=partial(update_progress, chunk_size, lock)
|
|
288
|
+
)
|
|
289
|
+
results.append(result)
|
|
290
|
+
|
|
291
|
+
final_results = [r.get() for r in results]
|
|
292
|
+
pool.close()
|
|
293
|
+
pool.join()
|
|
294
|
+
return pd.concat(final_results, ignore_index=True)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -59,7 +59,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
|
59
59
|
if error_file == "no_bench_data":
|
|
60
60
|
err_msg = "Bench does not have data file."
|
|
61
61
|
elif error_file:
|
|
62
|
-
err_msg = f"Dump file: {error_file} not found."
|
|
62
|
+
err_msg = f"Dump file: {error_file} not found or read failed."
|
|
63
63
|
else:
|
|
64
64
|
err_msg = CompareConst.NO_BENCH
|
|
65
65
|
error_flag = True
|
|
@@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
|
70
70
|
error_flag = True
|
|
71
71
|
return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
|
|
72
72
|
if not n_value.shape: # 判断数据是否为0维张量
|
|
73
|
-
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
|
|
73
|
+
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', "
|
|
74
74
|
f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
|
|
75
75
|
error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
|
|
76
76
|
return n_value, b_value, error_flag, err_msg
|
|
@@ -168,8 +168,9 @@ def statistics_data_check(result_dict):
|
|
|
168
168
|
|
|
169
169
|
class TensorComparisonBasic(abc.ABC):
|
|
170
170
|
"""NPU和bench中npy数据的比较模板"""
|
|
171
|
+
|
|
171
172
|
@abc.abstractmethod
|
|
172
|
-
def apply(self, n_value, b_value, relative_err):
|
|
173
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
173
174
|
raise NotImplementedError
|
|
174
175
|
|
|
175
176
|
|
|
@@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value):
|
|
|
190
191
|
|
|
191
192
|
class GetCosineSimilarity(TensorComparisonBasic):
|
|
192
193
|
"""计算cosine相似度"""
|
|
194
|
+
|
|
193
195
|
@staticmethod
|
|
194
196
|
def correct_data(result):
|
|
195
197
|
if result == CompareConst.NAN:
|
|
@@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
198
200
|
return round(float(result), 6)
|
|
199
201
|
return result
|
|
200
202
|
|
|
201
|
-
def apply(self, n_value, b_value, relative_err):
|
|
202
|
-
if
|
|
203
|
-
return CompareConst.UNSUPPORTED,
|
|
203
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
204
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
205
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
204
206
|
|
|
205
207
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
206
208
|
if len(n_value) == 1:
|
|
@@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
224
226
|
return result, ""
|
|
225
227
|
|
|
226
228
|
|
|
229
|
+
class GetEuclideanDistance(TensorComparisonBasic):
|
|
230
|
+
"""计算欧式距离"""
|
|
231
|
+
|
|
232
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
233
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
234
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
235
|
+
|
|
236
|
+
distance = np.linalg.norm(n_value - b_value, ord=2)
|
|
237
|
+
|
|
238
|
+
return distance, ""
|
|
239
|
+
|
|
240
|
+
|
|
227
241
|
class GetMaxAbsErr(TensorComparisonBasic):
|
|
228
242
|
"""计算最大绝对误差"""
|
|
229
|
-
|
|
243
|
+
|
|
244
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
230
245
|
temp_res = n_value - b_value
|
|
231
246
|
max_value = np.max(np.abs(temp_res))
|
|
232
247
|
if np.isnan(max_value):
|
|
@@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic):
|
|
|
237
252
|
|
|
238
253
|
class GetMaxRelativeErr(TensorComparisonBasic):
|
|
239
254
|
"""计算最大相对误差"""
|
|
240
|
-
|
|
255
|
+
|
|
256
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
241
257
|
max_relative_err = np.max(np.abs(relative_err))
|
|
242
258
|
if np.isnan(max_relative_err):
|
|
243
259
|
msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
|
|
@@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic):
|
|
|
247
263
|
|
|
248
264
|
class GetErrRatio(TensorComparisonBasic):
|
|
249
265
|
"""计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
|
|
266
|
+
|
|
250
267
|
def __init__(self, threshold):
|
|
251
268
|
self.threshold = threshold
|
|
252
269
|
|
|
253
|
-
def apply(self, n_value, b_value, relative_err):
|
|
254
|
-
if
|
|
255
|
-
return CompareConst.UNSUPPORTED,
|
|
270
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
271
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
272
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
256
273
|
|
|
257
274
|
if not np.size(relative_err):
|
|
258
275
|
return CompareConst.NAN, ""
|
|
@@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic):
|
|
|
264
281
|
class CompareOps:
|
|
265
282
|
compare_ops = {
|
|
266
283
|
"cosine_similarity": GetCosineSimilarity(),
|
|
284
|
+
"euclidean_distance": GetEuclideanDistance(),
|
|
267
285
|
"max_abs_error": GetMaxAbsErr(),
|
|
268
286
|
"max_relative_error": GetMaxRelativeErr(),
|
|
269
287
|
"one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
|
|
@@ -272,10 +290,8 @@ class CompareOps:
|
|
|
272
290
|
|
|
273
291
|
|
|
274
292
|
def error_value_process(n_value):
|
|
275
|
-
if n_value
|
|
293
|
+
if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE]:
|
|
276
294
|
return CompareConst.UNSUPPORTED, ""
|
|
277
|
-
if n_value == CompareConst.NONE:
|
|
278
|
-
return 0, ""
|
|
279
295
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
280
296
|
return CompareConst.SHAPE_UNMATCH, ""
|
|
281
297
|
if n_value == CompareConst.NAN:
|
|
@@ -295,7 +311,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
|
|
|
295
311
|
n_value, b_value = reshape_value(n_value, b_value)
|
|
296
312
|
|
|
297
313
|
for op in CompareOps.compare_ops.values():
|
|
298
|
-
result, msg = op.apply(n_value, b_value, relative_err)
|
|
314
|
+
result, msg = op.apply(n_value, b_value, relative_err, err_msg)
|
|
299
315
|
result_list.append(result)
|
|
300
316
|
err_msg += msg
|
|
301
317
|
return result_list, err_msg
|