mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -29,12 +29,16 @@ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
|
29
29
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
|
|
30
30
|
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
|
|
31
31
|
ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
|
|
32
|
-
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
|
|
33
|
-
|
|
32
|
+
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_input import PrecisionCompareInput
|
|
34
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
|
|
35
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpPrecisionCompare
|
|
36
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkPrecisionCompare
|
|
37
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
34
38
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
|
|
35
39
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
|
|
36
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
|
|
37
|
-
from msprobe.core.common.file_utils import FileChecker, change_mode,
|
|
40
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
|
|
41
|
+
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
38
42
|
from msprobe.pytorch.common.log import logger
|
|
39
43
|
from msprobe.core.common.utils import CompareException
|
|
40
44
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
@@ -47,30 +51,6 @@ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_va
|
|
|
47
51
|
'eb_inf_nan_consistency'])
|
|
48
52
|
UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
|
|
49
53
|
|
|
50
|
-
DEFAULT_THRESHOLD = 1
|
|
51
|
-
|
|
52
|
-
benchmark_algorithms_thresholds = {
|
|
53
|
-
'small_value': {
|
|
54
|
-
'error_threshold': 2,
|
|
55
|
-
'warning_threshold': 1
|
|
56
|
-
},
|
|
57
|
-
'rmse': {
|
|
58
|
-
'error_threshold': 2,
|
|
59
|
-
'warning_threshold': 1
|
|
60
|
-
},
|
|
61
|
-
'max_rel_err': {
|
|
62
|
-
'error_threshold': 10,
|
|
63
|
-
'warning_threshold': 1
|
|
64
|
-
},
|
|
65
|
-
'mean_rel_err': {
|
|
66
|
-
'error_threshold': 2,
|
|
67
|
-
'warning_threshold': 1
|
|
68
|
-
},
|
|
69
|
-
'eb': {
|
|
70
|
-
'error_threshold': 2,
|
|
71
|
-
'warning_threshold': 1
|
|
72
|
-
}
|
|
73
|
-
}
|
|
74
54
|
|
|
75
55
|
benchmark_message = {
|
|
76
56
|
"small_value_err_status": {
|
|
@@ -92,189 +72,6 @@ benchmark_message = {
|
|
|
92
72
|
}
|
|
93
73
|
|
|
94
74
|
|
|
95
|
-
class Standard:
|
|
96
|
-
@staticmethod
|
|
97
|
-
def _calc_ratio(column_name, x, y, default_value):
|
|
98
|
-
'''
|
|
99
|
-
计算npu侧和gpu侧统计量的比值
|
|
100
|
-
输入:
|
|
101
|
-
column_name:统计量名称
|
|
102
|
-
x:npu侧统计量
|
|
103
|
-
y:gpu侧统计量
|
|
104
|
-
default:当x不接近0,y接近0,设置的比值默认值
|
|
105
|
-
输出:
|
|
106
|
-
ratio:统计量x和y的比值
|
|
107
|
-
inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
|
|
108
|
-
message:当出现inf或nan时的提示信息
|
|
109
|
-
'''
|
|
110
|
-
x, y = convert_str_to_float(x), convert_str_to_float(y)
|
|
111
|
-
|
|
112
|
-
if is_inf_or_nan(x) or is_inf_or_nan(y):
|
|
113
|
-
return check_inf_or_nan(x, y, column_name)
|
|
114
|
-
|
|
115
|
-
inf_nan_consistency = True
|
|
116
|
-
message = ""
|
|
117
|
-
if math.isclose(y, 0.0):
|
|
118
|
-
if math.isclose(x, 0.0):
|
|
119
|
-
return 1.0, inf_nan_consistency, message
|
|
120
|
-
else:
|
|
121
|
-
return default_value, inf_nan_consistency, message
|
|
122
|
-
else:
|
|
123
|
-
return abs(x / y), inf_nan_consistency, message
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class BenchmarkStandard(Standard):
|
|
127
|
-
def __init__(self, api_name, npu_precision, gpu_precision):
|
|
128
|
-
self.api_name = api_name
|
|
129
|
-
self.npu_precision = npu_precision
|
|
130
|
-
self.gpu_precision = gpu_precision
|
|
131
|
-
self.small_value_err_ratio = 1
|
|
132
|
-
self.rmse_ratio = 1
|
|
133
|
-
self.max_rel_err_ratio = 1
|
|
134
|
-
self.mean_rel_err_ratio = 1
|
|
135
|
-
self.eb_ratio = 1
|
|
136
|
-
self.small_value_err_status = CompareConst.PASS
|
|
137
|
-
self.rmse_status = CompareConst.PASS
|
|
138
|
-
self.max_rel_err_status = CompareConst.PASS
|
|
139
|
-
self.mean_rel_err_status = CompareConst.PASS
|
|
140
|
-
self.eb_status = CompareConst.PASS
|
|
141
|
-
self.check_result_list = []
|
|
142
|
-
self.final_result = CompareConst.PASS
|
|
143
|
-
self.compare_message = ""
|
|
144
|
-
|
|
145
|
-
def __str__(self):
|
|
146
|
-
return "%s" % (self.api_name)
|
|
147
|
-
|
|
148
|
-
@staticmethod
|
|
149
|
-
def _get_status(ratio, algorithm):
|
|
150
|
-
if math.isnan(ratio) or math.isinf(ratio):
|
|
151
|
-
return CompareConst.PASS
|
|
152
|
-
error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
|
|
153
|
-
warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
|
|
154
|
-
DEFAULT_THRESHOLD)
|
|
155
|
-
if ratio > error_threshold:
|
|
156
|
-
return CompareConst.ERROR
|
|
157
|
-
elif ratio > warning_threshold:
|
|
158
|
-
return CompareConst.WARNING
|
|
159
|
-
return CompareConst.PASS
|
|
160
|
-
|
|
161
|
-
def get_result(self):
|
|
162
|
-
inf_nan_consistency = self._compare_ratio()
|
|
163
|
-
small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
|
|
164
|
-
rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
|
|
165
|
-
max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
|
|
166
|
-
mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
|
|
167
|
-
eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
|
|
168
|
-
self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
|
|
169
|
-
small_value_inf_nan_consistency else CompareConst.ERROR
|
|
170
|
-
self.check_result_list.append(self.small_value_err_status)
|
|
171
|
-
self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
|
|
172
|
-
else CompareConst.ERROR
|
|
173
|
-
self.check_result_list.append(self.rmse_status)
|
|
174
|
-
self.max_rel_err_status = self._get_status(
|
|
175
|
-
self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
|
|
176
|
-
self.check_result_list.append(self.max_rel_err_status)
|
|
177
|
-
self.mean_rel_err_status = self._get_status(
|
|
178
|
-
self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
|
|
179
|
-
self.check_result_list.append(self.mean_rel_err_status)
|
|
180
|
-
self.eb_status = self._get_status(self.eb_ratio, 'eb')
|
|
181
|
-
if CompareConst.ERROR in self.check_result_list:
|
|
182
|
-
self.final_result = CompareConst.ERROR
|
|
183
|
-
elif CompareConst.WARNING in self.check_result_list:
|
|
184
|
-
self.final_result = CompareConst.WARNING
|
|
185
|
-
|
|
186
|
-
def to_column_value(self):
|
|
187
|
-
return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
|
|
188
|
-
self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
|
|
189
|
-
self.mean_rel_err_status, self.eb_ratio, self.eb_status]
|
|
190
|
-
|
|
191
|
-
def _compare_ratio(self):
|
|
192
|
-
|
|
193
|
-
self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
|
|
194
|
-
ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
|
|
195
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
|
|
196
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
|
|
197
|
-
self.compare_message += small_value_message
|
|
198
|
-
self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
|
|
199
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
|
|
200
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
|
|
201
|
-
self.compare_message += rmse_message
|
|
202
|
-
self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
|
|
203
|
-
ApiPrecisionCompareColumn.MAX_REL_ERR,
|
|
204
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
|
|
205
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
|
|
206
|
-
self.compare_message += max_rel_message
|
|
207
|
-
self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
|
|
208
|
-
ApiPrecisionCompareColumn.MEAN_REL_ERR,
|
|
209
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
|
|
210
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
|
|
211
|
-
self.compare_message += mean_rel_message
|
|
212
|
-
self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
|
|
213
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.EB),
|
|
214
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
|
|
215
|
-
self.compare_message += eb_message
|
|
216
|
-
|
|
217
|
-
return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
|
|
218
|
-
max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
|
|
219
|
-
eb_inf_nan_consistency)
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
class ULPStandard(Standard):
|
|
223
|
-
def __init__(self, api_name, npu_precision, gpu_precision):
|
|
224
|
-
self.api_name = api_name
|
|
225
|
-
self.npu_precision = npu_precision
|
|
226
|
-
self.gpu_precision = gpu_precision
|
|
227
|
-
self.mean_ulp_err = 0
|
|
228
|
-
self.ulp_err_proportion = 0
|
|
229
|
-
self.ulp_err_proportion_ratio = 1
|
|
230
|
-
self.ulp_err_status = CompareConst.PASS
|
|
231
|
-
self.compare_message = ""
|
|
232
|
-
|
|
233
|
-
def __str__(self):
|
|
234
|
-
return f"{self.api_name}"
|
|
235
|
-
|
|
236
|
-
def get_result(self):
|
|
237
|
-
self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
|
|
238
|
-
gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
|
|
239
|
-
inf_nan_consistency = True
|
|
240
|
-
if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
|
|
241
|
-
_, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
|
|
242
|
-
ApiPrecisionCompareColumn.MEAN_ULP_ERR)
|
|
243
|
-
self.compare_message += message
|
|
244
|
-
self.ulp_err_proportion = convert_str_to_float(
|
|
245
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
|
|
246
|
-
self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
|
|
247
|
-
ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
|
|
248
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
|
|
249
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
|
|
250
|
-
inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
|
|
251
|
-
self.compare_message += message
|
|
252
|
-
if inf_nan_consistency:
|
|
253
|
-
self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
|
|
254
|
-
else:
|
|
255
|
-
self.ulp_err_status = CompareConst.ERROR
|
|
256
|
-
|
|
257
|
-
def _get_ulp_status(self, dtype):
|
|
258
|
-
if dtype == torch.float32:
|
|
259
|
-
if self.mean_ulp_err < 64:
|
|
260
|
-
return CompareConst.PASS
|
|
261
|
-
elif self.ulp_err_proportion < 0.05:
|
|
262
|
-
return CompareConst.PASS
|
|
263
|
-
elif self.ulp_err_proportion_ratio < 1:
|
|
264
|
-
return CompareConst.PASS
|
|
265
|
-
else:
|
|
266
|
-
self.compare_message += "ERROR: ULP误差不满足标准\n"
|
|
267
|
-
return CompareConst.ERROR
|
|
268
|
-
else:
|
|
269
|
-
if self.ulp_err_proportion < 0.001:
|
|
270
|
-
return CompareConst.PASS
|
|
271
|
-
elif self.ulp_err_proportion_ratio < 1:
|
|
272
|
-
return CompareConst.PASS
|
|
273
|
-
else:
|
|
274
|
-
self.compare_message += "ERROR: ULP误差不满足标准\n"
|
|
275
|
-
return CompareConst.ERROR
|
|
276
|
-
|
|
277
|
-
|
|
278
75
|
def write_detail_csv(content, save_path):
|
|
279
76
|
rows = []
|
|
280
77
|
content = ["{:.{}f}".format(item, msCheckerConfig.precision) \
|
|
@@ -283,6 +80,17 @@ def write_detail_csv(content, save_path):
|
|
|
283
80
|
write_csv(rows, save_path)
|
|
284
81
|
|
|
285
82
|
|
|
83
|
+
def register_compare_func():
|
|
84
|
+
registry = StandardRegistry()
|
|
85
|
+
registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result)
|
|
86
|
+
registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result)
|
|
87
|
+
registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result)
|
|
88
|
+
registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result)
|
|
89
|
+
registry.register(CompareConst.BENCHMARK, record_benchmark_compare_result)
|
|
90
|
+
registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result)
|
|
91
|
+
return registry
|
|
92
|
+
|
|
93
|
+
|
|
286
94
|
def api_precision_compare(config):
|
|
287
95
|
logger.info("Start compare task")
|
|
288
96
|
logger.info(f"Compare task result will be saved in {config.result_csv_path}")
|
|
@@ -337,6 +145,8 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
337
145
|
forward_status, backward_status = [], []
|
|
338
146
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
339
147
|
last_api_skip_message = ''
|
|
148
|
+
registry = register_compare_func()
|
|
149
|
+
|
|
340
150
|
for _, row_npu in npu_data.iterrows():
|
|
341
151
|
message = ''
|
|
342
152
|
compare_column = ApiPrecisionOutputColumn()
|
|
@@ -362,7 +172,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
362
172
|
row_gpu = row_gpu.iloc[0]
|
|
363
173
|
new_status = CompareConst.SPACE
|
|
364
174
|
try:
|
|
365
|
-
new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
|
|
175
|
+
new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry)
|
|
366
176
|
except Exception as err:
|
|
367
177
|
logger.error(f"Get api status error: {str(err)}")
|
|
368
178
|
compare_column.api_name = full_api_name_with_direction_status
|
|
@@ -383,7 +193,8 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
383
193
|
else:
|
|
384
194
|
forward_result = get_api_checker_result(forward_status)
|
|
385
195
|
backward_result = get_api_checker_result(backward_status)
|
|
386
|
-
|
|
196
|
+
_, base_api_name = extract_basic_api_segments(last_api_name)
|
|
197
|
+
message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
387
198
|
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
388
199
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
389
200
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
@@ -415,37 +226,30 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
415
226
|
else:
|
|
416
227
|
forward_result = get_api_checker_result(forward_status)
|
|
417
228
|
backward_result = get_api_checker_result(backward_status)
|
|
418
|
-
|
|
229
|
+
_, base_api_name = extract_basic_api_segments(last_api_name)
|
|
230
|
+
message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
419
231
|
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
420
232
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
421
233
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
422
234
|
last_api_skip_message = ''
|
|
423
235
|
|
|
424
236
|
|
|
425
|
-
def get_api_status(row_npu, row_gpu, api_name, compare_column):
|
|
237
|
+
def get_api_status(row_npu, row_gpu, api_name, compare_column, registry):
|
|
426
238
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
427
239
|
# 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对
|
|
428
|
-
if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace()
|
|
240
|
+
if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace() or \
|
|
241
|
+
row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in API_PRECISION_COMPARE_UNSUPPORT_LIST or \
|
|
242
|
+
row_npu[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE:
|
|
429
243
|
compare_column.api_name = full_api_name_with_direction_status
|
|
430
244
|
compare_column.compare_result = CompareConst.SKIP
|
|
431
245
|
compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
|
|
432
246
|
new_status = CompareConst.SKIP
|
|
433
247
|
else:
|
|
434
248
|
compare_column.api_name = full_api_name_with_direction_status
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
|
|
440
|
-
elif api_name in absolute_standard_api:
|
|
441
|
-
new_status = record_absolute_threshold_result(compare_column, row_npu)
|
|
442
|
-
elif api_name in ulp_standard_api and \
|
|
443
|
-
row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
|
|
444
|
-
us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
|
|
445
|
-
new_status = record_ulp_compare_result(compare_column, us)
|
|
446
|
-
elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST:
|
|
447
|
-
bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
|
|
448
|
-
new_status = record_benchmark_compare_result(compare_column, bs)
|
|
249
|
+
dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
|
|
250
|
+
input_data = PrecisionCompareInput(row_npu, row_gpu, dtype, compare_column)
|
|
251
|
+
comparison_func = registry.get_comparison_function(api_name, dtype)
|
|
252
|
+
new_status = comparison_func(input_data)
|
|
449
253
|
return new_status
|
|
450
254
|
|
|
451
255
|
|
|
@@ -505,21 +309,24 @@ def check_csv_columns(columns, csv_type):
|
|
|
505
309
|
raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
|
|
506
310
|
|
|
507
311
|
|
|
508
|
-
def record_binary_consistency_result(
|
|
312
|
+
def record_binary_consistency_result(input_data):
|
|
313
|
+
row_npu = input_data.row_npu
|
|
314
|
+
compare_column = input_data.compare_column
|
|
509
315
|
new_status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE])
|
|
510
316
|
compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
|
|
511
317
|
compare_column.error_rate_status = new_status
|
|
512
318
|
compare_column.compare_result = new_status
|
|
513
|
-
compare_column.compare_algorithm =
|
|
319
|
+
compare_column.compare_algorithm = CompareConst.BINARY_CONSISTENCY_ALGORITHM_NAME
|
|
514
320
|
message = ''
|
|
515
321
|
if compare_column.error_rate_status == CompareConst.ERROR:
|
|
516
322
|
message += "ERROR: 二进制一致错误率超过阈值\n"
|
|
517
|
-
message += CompareMessage.get(api_name, "")
|
|
518
323
|
compare_column.compare_message = message
|
|
519
324
|
return new_status
|
|
520
325
|
|
|
521
326
|
|
|
522
|
-
def record_absolute_threshold_result(
|
|
327
|
+
def record_absolute_threshold_result(input_data):
|
|
328
|
+
row_npu = input_data.row_npu
|
|
329
|
+
compare_column = input_data.compare_column
|
|
523
330
|
absolute_threshold_result = get_absolute_threshold_result(row_npu)
|
|
524
331
|
compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
|
|
525
332
|
compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
|
|
@@ -528,62 +335,88 @@ def record_absolute_threshold_result(compare_column, row_npu):
|
|
|
528
335
|
compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
|
|
529
336
|
compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
|
|
530
337
|
compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result")
|
|
531
|
-
compare_column.compare_algorithm =
|
|
338
|
+
compare_column.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD_ALGORITHM_NAME
|
|
532
339
|
message = ''
|
|
533
340
|
if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
|
|
534
|
-
message += "ERROR: inf/nan
|
|
341
|
+
message += "ERROR: inf/nan错误率超过阈值"
|
|
535
342
|
if compare_column.rel_err_ratio_status == CompareConst.ERROR:
|
|
536
|
-
message += "ERROR:
|
|
343
|
+
message += "ERROR: 相对误差错误率超过阈值"
|
|
537
344
|
if compare_column.abs_err_ratio_status == CompareConst.ERROR:
|
|
538
|
-
message += "ERROR:
|
|
345
|
+
message += "ERROR: 绝对误差错误率超过阈值"
|
|
539
346
|
compare_column.compare_message = message
|
|
540
347
|
return compare_column.compare_result
|
|
541
348
|
|
|
542
349
|
|
|
543
|
-
def record_benchmark_compare_result(
|
|
544
|
-
bs
|
|
545
|
-
|
|
546
|
-
compare_column.small_value_err_status = bs.small_value_err_status
|
|
547
|
-
compare_column.rmse_ratio = bs.rmse_ratio
|
|
548
|
-
compare_column.rmse_status = bs.rmse_status
|
|
549
|
-
compare_column.max_rel_err_ratio = bs.max_rel_err_ratio
|
|
550
|
-
compare_column.max_rel_err_status = bs.max_rel_err_status
|
|
551
|
-
compare_column.mean_rel_err_ratio = bs.mean_rel_err_ratio
|
|
552
|
-
compare_column.mean_rel_err_status = bs.mean_rel_err_status
|
|
553
|
-
compare_column.eb_ratio = bs.eb_ratio
|
|
554
|
-
compare_column.eb_status = bs.eb_status
|
|
555
|
-
compare_column.compare_result = bs.final_result
|
|
556
|
-
compare_column.compare_algorithm = "标杆比对法"
|
|
557
|
-
compare_column.compare_message = bs.compare_message
|
|
350
|
+
def record_benchmark_compare_result(input_data):
|
|
351
|
+
bs = BenchmarkPrecisionCompare(input_data)
|
|
352
|
+
compare_result = bs.compare()
|
|
558
353
|
for status_attr, messages in benchmark_message.items():
|
|
559
|
-
status_value = getattr(compare_column, status_attr)
|
|
354
|
+
status_value = getattr(input_data.compare_column, status_attr)
|
|
560
355
|
if status_value in messages:
|
|
561
|
-
compare_column.compare_message += messages[status_value]
|
|
562
|
-
return
|
|
356
|
+
input_data.compare_column.compare_message += messages[status_value]
|
|
357
|
+
return compare_result
|
|
358
|
+
|
|
563
359
|
|
|
360
|
+
def record_ulp_compare_result(input_data):
|
|
361
|
+
us = UlpPrecisionCompare(input_data)
|
|
362
|
+
compare_result = us.compare()
|
|
363
|
+
return compare_result
|
|
564
364
|
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
compare_column
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
365
|
+
|
|
366
|
+
def record_accumulative_error_compare_result(input_data):
|
|
367
|
+
row_npu = input_data.row_npu
|
|
368
|
+
compare_column = input_data.compare_column
|
|
369
|
+
absolute_threshold_result = get_absolute_threshold_result(row_npu)
|
|
370
|
+
threshold_result = absolute_threshold_result.get("absolute_threshold_result")
|
|
371
|
+
eb, eb_result = check_eb(row_npu)
|
|
372
|
+
accumulative_error_compare_result = CompareConst.PASS
|
|
373
|
+
if CompareConst.ERROR in [threshold_result, eb_result]:
|
|
374
|
+
accumulative_error_compare_result = CompareConst.ERROR
|
|
375
|
+
|
|
376
|
+
compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
|
|
377
|
+
compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
|
|
378
|
+
compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio")
|
|
379
|
+
compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result")
|
|
380
|
+
compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
|
|
381
|
+
compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
|
|
382
|
+
compare_column.eb_ratio = eb
|
|
383
|
+
compare_column.eb_status = eb_result
|
|
384
|
+
compare_column.compare_result = accumulative_error_compare_result
|
|
385
|
+
compare_column.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME
|
|
386
|
+
message = []
|
|
387
|
+
if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
|
|
388
|
+
message.append("ERROR: inf/nan错误率超过阈值\n")
|
|
389
|
+
if compare_column.rel_err_ratio_status == CompareConst.ERROR:
|
|
390
|
+
message.append("ERROR: 相对误差错误率超过阈值\n")
|
|
391
|
+
if compare_column.abs_err_ratio_status == CompareConst.ERROR:
|
|
392
|
+
message.append("ERROR: 绝对误差错误率超过阈值\n")
|
|
393
|
+
if compare_column.eb_status == CompareConst.ERROR:
|
|
394
|
+
message.append("ERROR: 误差均衡性超过阈值\n")
|
|
395
|
+
compare_column.compare_message = "\n".join(message)
|
|
574
396
|
return compare_column.compare_result
|
|
575
397
|
|
|
576
398
|
|
|
399
|
+
def check_eb(row_npu):
|
|
400
|
+
eb = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.EB])
|
|
401
|
+
dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
|
|
402
|
+
eb_threshold = StandardConfig.get_accumulative_error_eb_threshold(dtype)
|
|
403
|
+
eb_result = CompareConst.PASS if eb <= eb_threshold else CompareConst.ERROR
|
|
404
|
+
return eb, eb_result
|
|
405
|
+
|
|
406
|
+
|
|
577
407
|
def check_thousandth_rate(thousandth_rate):
|
|
578
|
-
return CompareConst.PASS if convert_str_to_float(thousandth_rate) >=
|
|
408
|
+
return CompareConst.PASS if convert_str_to_float(thousandth_rate) >= CompareConst.THOUSANDTH_PASS_VALUE \
|
|
409
|
+
else CompareConst.ERROR
|
|
579
410
|
|
|
580
411
|
|
|
581
|
-
def record_thousandth_threshold_result(
|
|
412
|
+
def record_thousandth_threshold_result(input_data):
|
|
413
|
+
row_npu = input_data.row_npu
|
|
414
|
+
compare_column = input_data.compare_column
|
|
582
415
|
new_status = check_thousandth_rate(row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH])
|
|
583
416
|
compare_column.rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
|
|
584
417
|
compare_column.rel_err_thousandth_status = new_status
|
|
585
418
|
compare_column.compare_result = new_status
|
|
586
|
-
compare_column.compare_algorithm =
|
|
419
|
+
compare_column.compare_algorithm = CompareConst.THOUSANDTH_STANDARD_ALGORITHM_NAME
|
|
587
420
|
message = ''
|
|
588
421
|
if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
|
|
589
422
|
message += "ERROR: 双千指标不达标\n"
|
|
@@ -602,8 +435,7 @@ def _api_precision_compare(parser=None):
|
|
|
602
435
|
def _api_precision_compare_command(args):
|
|
603
436
|
npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
|
|
604
437
|
gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
|
|
605
|
-
out_path =
|
|
606
|
-
check_path_before_create(out_path)
|
|
438
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
607
439
|
create_directory(out_path)
|
|
608
440
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
609
441
|
out_path = out_path_checker.common_check()
|
|
@@ -621,7 +453,7 @@ def _api_precision_compare_parser(parser):
|
|
|
621
453
|
parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
|
|
622
454
|
help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
|
|
623
455
|
"api_accuracy_checker tool.",
|
|
624
|
-
required=
|
|
456
|
+
required=True)
|
|
625
457
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
626
458
|
help="<optional> The api precision compare task result out path.",
|
|
627
459
|
required=False)
|
|
@@ -66,6 +66,7 @@ BinaryCompareStandard:
|
|
|
66
66
|
- greater_
|
|
67
67
|
- greater_equal
|
|
68
68
|
- greater_equal_
|
|
69
|
+
- histc
|
|
69
70
|
- isfinite
|
|
70
71
|
- isnan
|
|
71
72
|
- less
|
|
@@ -130,4 +131,6 @@ ULPStandard:
|
|
|
130
131
|
ThousandthStandard:
|
|
131
132
|
- conv1d
|
|
132
133
|
- conv2d
|
|
133
|
-
|
|
134
|
+
|
|
135
|
+
AccumulativeErrorStandard:
|
|
136
|
+
- test_api
|