mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -24,15 +24,20 @@ from msprobe.core.common.utils import CompareException
|
|
|
24
24
|
from msprobe.core.common.file_utils import get_json_contents, write_csv
|
|
25
25
|
import torch
|
|
26
26
|
from msprobe.core.common.const import CompareConst
|
|
27
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
27
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
|
|
28
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare
|
|
29
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare
|
|
31
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare
|
|
32
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare
|
|
34
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput
|
|
35
|
+
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_err, get_max_abs_err, get_rel_err_ratio, \
|
|
36
|
+
cosine_sim, get_rel_err_origin, get_abs_bench_with_eps, compare_bool_tensor
|
|
31
37
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
32
38
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
33
39
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
|
|
34
|
-
DETAIL_TEST_ROWS,
|
|
35
|
-
ulp_standard_api, thousandth_standard_api, apis_threshold
|
|
40
|
+
DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
|
|
36
41
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
37
42
|
from msprobe.pytorch.common.log import logger
|
|
38
43
|
|
|
@@ -42,6 +47,7 @@ ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'b
|
|
|
42
47
|
|
|
43
48
|
|
|
44
49
|
INDEX_TEST_RESULT_GROUP = 3
|
|
50
|
+
BACKWARD_RESULT_GROUP = 4
|
|
45
51
|
INDEX_FIRST_GROUP = 0
|
|
46
52
|
INDEX_MESSAGE = -1
|
|
47
53
|
|
|
@@ -66,6 +72,8 @@ class Comparator:
|
|
|
66
72
|
self.detail_save_path_list = \
|
|
67
73
|
[self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
68
74
|
|
|
75
|
+
self.registry = self._register_compare_func()
|
|
76
|
+
|
|
69
77
|
if not is_continue_run_ut:
|
|
70
78
|
self.write_csv_title()
|
|
71
79
|
if stack_info_json_path:
|
|
@@ -101,22 +109,6 @@ class Comparator:
|
|
|
101
109
|
compare_column.error_rate = 0
|
|
102
110
|
return CompareConst.PASS, compare_column, ""
|
|
103
111
|
|
|
104
|
-
@staticmethod
|
|
105
|
-
def _compare_bool_tensor(bench_output, device_output):
|
|
106
|
-
error_nums = (bench_output != device_output).sum()
|
|
107
|
-
if bench_output.size == 0:
|
|
108
|
-
return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
|
|
109
|
-
error_rate = float(error_nums / bench_output.size)
|
|
110
|
-
result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
|
|
111
|
-
return error_rate, result, ""
|
|
112
|
-
|
|
113
|
-
@staticmethod
|
|
114
|
-
def _get_absolute_threshold_attribute(api_name, dtype):
|
|
115
|
-
small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
|
|
116
|
-
small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol')
|
|
117
|
-
rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
|
|
118
|
-
return small_value_threshold, small_value_atol, rtol
|
|
119
|
-
|
|
120
112
|
@staticmethod
|
|
121
113
|
def _get_run_ut_detail(test_result):
|
|
122
114
|
"""get run_ut detail before write to csv, called by online run_ut"""
|
|
@@ -143,6 +135,36 @@ class Comparator:
|
|
|
143
135
|
test_rows.append([subject] + list(test_subject))
|
|
144
136
|
return test_rows
|
|
145
137
|
|
|
138
|
+
@staticmethod
|
|
139
|
+
def _binary_standard_compare(input_data):
|
|
140
|
+
binary_compare = BinaryCompare(input_data)
|
|
141
|
+
binary_compare.compare()
|
|
142
|
+
|
|
143
|
+
@staticmethod
|
|
144
|
+
def _thousandth_standard_compare(input_data):
|
|
145
|
+
thousandth_compare = ThousandthStdCompare(input_data)
|
|
146
|
+
thousandth_compare.compare()
|
|
147
|
+
|
|
148
|
+
@staticmethod
|
|
149
|
+
def _absolute_standard_compare(input_data):
|
|
150
|
+
absolute_compare = AbsolutethdCompare(input_data)
|
|
151
|
+
absolute_compare.compare()
|
|
152
|
+
|
|
153
|
+
@staticmethod
|
|
154
|
+
def _ulp_compare(input_data):
|
|
155
|
+
ulp_compare = UlpCompare(input_data)
|
|
156
|
+
ulp_compare.compare()
|
|
157
|
+
|
|
158
|
+
@staticmethod
|
|
159
|
+
def _benchmark_compare(input_data):
|
|
160
|
+
benchmark_compare = BenchmarkCompare(input_data)
|
|
161
|
+
benchmark_compare.compare()
|
|
162
|
+
|
|
163
|
+
@staticmethod
|
|
164
|
+
def _accumulative_error_compare(input_data):
|
|
165
|
+
accumulative_error_compare = AccumulativeErrorCompare(input_data)
|
|
166
|
+
accumulative_error_compare.compare()
|
|
167
|
+
|
|
146
168
|
def write_csv_title(self):
|
|
147
169
|
summary_test_rows = [
|
|
148
170
|
[self.COLUMN_API_NAME,
|
|
@@ -163,6 +185,8 @@ class Comparator:
|
|
|
163
185
|
df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
|
|
164
186
|
if test_result[1] == CompareConst.SKIP:
|
|
165
187
|
df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
|
|
188
|
+
elif test_result[2] == CompareConst.SKIP:
|
|
189
|
+
df_row.append(test_result[BACKWARD_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
|
|
166
190
|
if self.stack_info:
|
|
167
191
|
stack_info = "\n".join(self.stack_info[name])
|
|
168
192
|
df_row.append(stack_info)
|
|
@@ -211,6 +235,7 @@ class Comparator:
|
|
|
211
235
|
if backward_message:
|
|
212
236
|
backward_column = CompareColumn()
|
|
213
237
|
bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
|
|
238
|
+
bwd_success_status = CompareConst.SKIP
|
|
214
239
|
else:
|
|
215
240
|
bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
|
|
216
241
|
result_info = ResultInfo(full_api_name,
|
|
@@ -226,6 +251,16 @@ class Comparator:
|
|
|
226
251
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
227
252
|
or bwd_success_status == CompareConst.SPACE
|
|
228
253
|
|
|
254
|
+
def _register_compare_func(self):
|
|
255
|
+
registry = StandardRegistry()
|
|
256
|
+
registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare)
|
|
257
|
+
registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare)
|
|
258
|
+
registry.register(CompareConst.ULP_COMPARE, self._ulp_compare)
|
|
259
|
+
registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare)
|
|
260
|
+
registry.register(CompareConst.BENCHMARK, self._benchmark_compare)
|
|
261
|
+
registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare)
|
|
262
|
+
return registry
|
|
263
|
+
|
|
229
264
|
def _compare_core_wrapper(self, api_name, bench_output, device_output):
|
|
230
265
|
detailed_result_total = []
|
|
231
266
|
test_final_success = CompareConst.PASS
|
|
@@ -308,11 +343,13 @@ class Comparator:
|
|
|
308
343
|
return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
|
|
309
344
|
f"npu output dtype is {device_output.dtype}, cannot compare."
|
|
310
345
|
message = ""
|
|
346
|
+
if bench_output.size == 0:
|
|
347
|
+
return CompareConst.ERROR, compare_column, "There is not bench calculation result."
|
|
311
348
|
if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
|
|
312
349
|
np.int64, np.uint64]:
|
|
313
350
|
message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
|
|
314
351
|
f"Only judged by Error Rate."
|
|
315
|
-
err_rate, status, msg =
|
|
352
|
+
err_rate, status, msg = compare_bool_tensor(bench_output, device_output)
|
|
316
353
|
message += msg + "\n"
|
|
317
354
|
compare_column.error_rate = err_rate
|
|
318
355
|
return status, compare_column, message
|
|
@@ -321,56 +358,20 @@ class Comparator:
|
|
|
321
358
|
compare_column, npu_dtype)
|
|
322
359
|
return status, compare_column, message
|
|
323
360
|
|
|
361
|
+
def _perform_comparison(self, api_name, input_data):
|
|
362
|
+
comparison_func = self.registry.get_comparison_function(api_name, None)
|
|
363
|
+
comparison_func(input_data)
|
|
364
|
+
|
|
324
365
|
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
|
|
325
366
|
message = ""
|
|
326
|
-
|
|
367
|
+
_, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
|
|
327
368
|
abs_err = get_abs_err(bench_output, device_output)
|
|
328
369
|
rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
|
|
329
|
-
|
|
330
|
-
thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
|
|
331
|
-
compare_column.rel_err_thousandth = thousand_res
|
|
370
|
+
input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign)
|
|
332
371
|
if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
compare_column.error_rate = err_rate
|
|
337
|
-
elif api_name in absolute_standard_api:
|
|
338
|
-
small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
|
|
339
|
-
api_name, str(dtype))
|
|
340
|
-
rel_err = abs_err / abs_bench_with_eps
|
|
341
|
-
small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
|
|
342
|
-
normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
|
|
343
|
-
compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
|
|
344
|
-
dtype, rtol)
|
|
345
|
-
compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
|
|
346
|
-
compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
|
|
347
|
-
elif api_name in ulp_standard_api:
|
|
348
|
-
if bench_output.size == 0:
|
|
349
|
-
compare_column.max_ulp_error = 0
|
|
350
|
-
compare_column.mean_ulp_error = 0
|
|
351
|
-
compare_column.ulp_error_proportion = 0
|
|
352
|
-
else:
|
|
353
|
-
ulp_err = get_ulp_err(bench_output, device_output, dtype)
|
|
354
|
-
compare_column.max_ulp_error = np.max(ulp_err)
|
|
355
|
-
compare_column.mean_ulp_error = np.mean(ulp_err)
|
|
356
|
-
if dtype == torch.float32:
|
|
357
|
-
compare_column.ulp_error_proportion = \
|
|
358
|
-
np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
|
|
359
|
-
else:
|
|
360
|
-
compare_column.ulp_error_proportion = \
|
|
361
|
-
np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
|
|
362
|
-
else:
|
|
363
|
-
dtype_config = precision_configs.get(dtype)
|
|
364
|
-
small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
|
|
365
|
-
abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
|
|
366
|
-
compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
|
|
367
|
-
rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
|
|
368
|
-
compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
|
|
369
|
-
compare_column.eb = get_error_balance(bench_output, device_output)
|
|
370
|
-
if rel_err.size == 0:
|
|
371
|
-
return CompareConst.ERROR, compare_column, "Relative error result list is empty."
|
|
372
|
-
compare_column.max_rel_error = get_max_rel_err(rel_err)
|
|
373
|
-
compare_column.mean_rel_error = get_mean_rel_err(rel_err)
|
|
372
|
+
self._perform_comparison(api_name, input_data)
|
|
373
|
+
else:
|
|
374
|
+
message += f"The data type {dtype} is not supported for new precision standard."
|
|
374
375
|
|
|
375
376
|
cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
|
|
376
377
|
compare_column.cosine_sim = cos_res
|
|
@@ -16,9 +16,17 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
from msprobe.core.common.const import CompareConst
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class CompareColumn:
|
|
23
|
+
__slots__ = [
|
|
24
|
+
'bench_type', 'npu_type', 'shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth',
|
|
25
|
+
'rel_err_ten_thousandth', 'inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio',
|
|
26
|
+
'small_value_err_ratio', 'max_rel_error', 'mean_rel_error', 'rmse', 'eb', 'max_ulp_error',
|
|
27
|
+
'mean_ulp_error', 'ulp_error_proportion', 'error_rate', 'rel_err_thousandth'
|
|
28
|
+
]
|
|
29
|
+
|
|
22
30
|
def __init__(self):
|
|
23
31
|
self.bench_type = CompareConst.SPACE
|
|
24
32
|
self.npu_type = CompareConst.SPACE
|
|
@@ -41,6 +49,24 @@ class CompareColumn:
|
|
|
41
49
|
self.mean_ulp_error = CompareConst.SPACE
|
|
42
50
|
self.ulp_error_proportion = CompareConst.SPACE
|
|
43
51
|
|
|
52
|
+
def update(self, metrics):
|
|
53
|
+
"""
|
|
54
|
+
Updates the object's attributes with the provided metrics.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
metrics (dict): A dictionary containing attribute names and their corresponding values.
|
|
58
|
+
|
|
59
|
+
Raises:
|
|
60
|
+
AttributeError: If the metric key is not a valid attribute of CompareColumn.
|
|
61
|
+
"""
|
|
62
|
+
for key, value in metrics.items():
|
|
63
|
+
if value is None:
|
|
64
|
+
continue
|
|
65
|
+
if key not in self.__slots__:
|
|
66
|
+
logger.error(f"The key '{key}' is not a valid attribute of CompareColumn.")
|
|
67
|
+
continue
|
|
68
|
+
setattr(self, key, value)
|
|
69
|
+
|
|
44
70
|
def to_column_value(self, is_pass, message):
|
|
45
71
|
return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
|
|
46
72
|
self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
|
|
@@ -50,6 +76,16 @@ class CompareColumn:
|
|
|
50
76
|
|
|
51
77
|
|
|
52
78
|
class ApiPrecisionOutputColumn:
|
|
79
|
+
__slots__ = [
|
|
80
|
+
'api_name', 'small_value_err_ratio', 'small_value_err_status', 'rmse_ratio', 'rmse_status',
|
|
81
|
+
'max_rel_err_ratio', 'max_rel_err_status', 'mean_rel_err_ratio', 'mean_rel_err_status', 'eb_ratio',
|
|
82
|
+
'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status', 'rel_err_ratio',
|
|
83
|
+
'rel_err_ratio_status', 'abs_err_ratio', 'abs_err_ratio_status', 'error_rate', 'error_rate_status',
|
|
84
|
+
'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status',
|
|
85
|
+
'rel_err_thousandth', 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm',
|
|
86
|
+
'compare_message'
|
|
87
|
+
]
|
|
88
|
+
|
|
53
89
|
def __init__(self):
|
|
54
90
|
self.api_name = CompareConst.SPACE
|
|
55
91
|
self.small_value_err_ratio = CompareConst.SPACE
|
|
@@ -80,6 +116,24 @@ class ApiPrecisionOutputColumn:
|
|
|
80
116
|
self.compare_algorithm = CompareConst.SPACE
|
|
81
117
|
self.compare_message = CompareConst.SPACE
|
|
82
118
|
|
|
119
|
+
def update(self, metrics):
|
|
120
|
+
"""
|
|
121
|
+
Updates the object's attributes with the provided metrics.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
metrics (dict): A dictionary containing attribute names and their corresponding values.
|
|
125
|
+
|
|
126
|
+
Raises:
|
|
127
|
+
AttributeError: If the metric key is not a valid attribute of CompareColumn.
|
|
128
|
+
"""
|
|
129
|
+
for key, value in metrics.items():
|
|
130
|
+
if value is None:
|
|
131
|
+
continue
|
|
132
|
+
if key not in self.__slots__:
|
|
133
|
+
logger.error("The key '%s' is not a valid attribute of CompareColumn.", key)
|
|
134
|
+
continue
|
|
135
|
+
setattr(self, key, value)
|
|
136
|
+
|
|
83
137
|
def to_column_value(self):
|
|
84
138
|
return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
|
|
85
139
|
self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CompareInput:
    """
    A class to encapsulate the input data required for comparison operations.

    Attributes:
        bench_output (np.ndarray): The benchmark output values.
        device_output (np.ndarray): The device output values.
        compare_column (class): A class to store and update comparison metrics.
        dtype (type, optional): The data type of the outputs. Defaults to None.
        rel_err_orign (float or array-like, optional): The original relative error values. Defaults to None.

    Raises:
        TypeError: If bench_output or device_output is not a numpy array.
    """
    def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None):
        # Validate before touching the instance so a rejected call never
        # leaves a partially populated object behind.
        if not (isinstance(bench_output, np.ndarray) and isinstance(device_output, np.ndarray)):
            raise TypeError("The input should be numpy array")
        self.bench_output = bench_output
        self.device_output = device_output
        self.compare_column = compare_column
        self.dtype = dtype
        self.rel_err_orign = rel_err_orign
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PrecisionCompareInput:
    """
    A plain container that bundles the arguments of one precision comparison.

    The constructor stores every argument unchanged under an attribute of
    the same name; nothing is validated or converted.

    Attributes:
        row_npu: Value passed as ``row_npu`` (presumably the NPU result row - confirm against callers).
        row_gpu: Value passed as ``row_gpu`` (presumably the GPU result row - confirm against callers).
        dtype: Value passed as ``dtype``.
        compare_column: Value passed as ``compare_column``.
    """
    def __init__(self, row_npu, row_gpu, dtype, compare_column):
        # Paired assignments; attribute creation order matches the parameter order.
        self.row_npu, self.row_gpu = row_npu, row_gpu
        self.dtype, self.compare_column = dtype, compare_column
|
|
@@ -43,10 +43,7 @@ absolute_standard_api = apis.get('AbsoluteThreshStandard')
|
|
|
43
43
|
binary_standard_api = apis.get('BinaryCompareStandard')
|
|
44
44
|
ulp_standard_api = apis.get('ULPStandard')
|
|
45
45
|
thousandth_standard_api = apis.get('ThousandthStandard')
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
|
|
49
|
-
apis_threshold = load_yaml(threshold_yaml_path)
|
|
46
|
+
accumulative_error_standard_api = apis.get('AccumulativeErrorStandard')
|
|
50
47
|
|
|
51
48
|
|
|
52
49
|
DETAIL_TEST_ROWS = [
|
|
@@ -134,6 +131,7 @@ ULP_PARAMETERS = {
|
|
|
134
131
|
class ApiPrecisionCompareColumn:
|
|
135
132
|
API_NAME = 'API Name'
|
|
136
133
|
DEVICE_DTYPE = 'DEVICE Dtype'
|
|
134
|
+
SHAPE = 'Shape'
|
|
137
135
|
SMALL_VALUE_ERROR_RATE = '小值域错误占比'
|
|
138
136
|
RMSE = '均方根误差'
|
|
139
137
|
MAX_REL_ERR = '相对误差最大值'
|