mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -14,18 +14,31 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
|
|
17
18
|
import numpy as np
|
|
18
|
-
|
|
19
|
+
|
|
19
20
|
from msprobe.core.common.const import Const, CompareConst
|
|
20
21
|
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.common.utils import CompareException, format_value
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
def handle_inf_nan(n_value, b_value):
|
|
26
|
+
def convert_to_float(value):
|
|
27
|
+
try:
|
|
28
|
+
if isinstance(value, np.ndarray):
|
|
29
|
+
return value.astype(float)
|
|
30
|
+
else:
|
|
31
|
+
return float(value)
|
|
32
|
+
except ValueError as e:
|
|
33
|
+
logger.error('\n'.join(e.args))
|
|
34
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
35
|
+
|
|
36
|
+
n_value_convert, b_value_convert = convert_to_float(n_value), convert_to_float(b_value)
|
|
24
37
|
"""处理inf和nan的数据"""
|
|
25
|
-
n_inf = np.isinf(
|
|
26
|
-
b_inf = np.isinf(
|
|
27
|
-
n_nan = np.isnan(
|
|
28
|
-
b_nan = np.isnan(
|
|
38
|
+
n_inf = np.isinf(n_value_convert)
|
|
39
|
+
b_inf = np.isinf(b_value_convert)
|
|
40
|
+
n_nan = np.isnan(n_value_convert)
|
|
41
|
+
b_nan = np.isnan(b_value_convert)
|
|
29
42
|
n_invalid = np.any(n_inf) or np.any(n_nan)
|
|
30
43
|
b_invalid = np.any(b_inf) or np.any(b_nan)
|
|
31
44
|
if n_invalid or b_invalid:
|
|
@@ -39,58 +52,66 @@ def handle_inf_nan(n_value, b_value):
|
|
|
39
52
|
return n_value, b_value
|
|
40
53
|
|
|
41
54
|
|
|
42
|
-
def
|
|
43
|
-
"""判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag"""
|
|
55
|
+
def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
56
|
+
"""判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag和error_msg"""
|
|
57
|
+
err_msg = ""
|
|
44
58
|
if error_flag:
|
|
45
|
-
|
|
59
|
+
if error_file == "no_bench_data":
|
|
60
|
+
err_msg = "Bench does not have data file."
|
|
61
|
+
elif error_file:
|
|
62
|
+
err_msg = f"Dump file: {error_file} not found."
|
|
63
|
+
else:
|
|
64
|
+
err_msg = CompareConst.NO_BENCH
|
|
65
|
+
error_flag = True
|
|
66
|
+
return CompareConst.READ_NONE, CompareConst.READ_NONE, error_flag, err_msg
|
|
67
|
+
|
|
46
68
|
if n_value.size == 0: # 判断读取到的数据是否为空
|
|
47
|
-
|
|
69
|
+
err_msg = "This is empty data, can not compare."
|
|
70
|
+
error_flag = True
|
|
71
|
+
return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
|
|
72
|
+
if not n_value.shape: # 判断数据是否为0维张量
|
|
73
|
+
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
|
|
74
|
+
f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
|
|
75
|
+
error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
|
|
76
|
+
return n_value, b_value, error_flag, err_msg
|
|
48
77
|
if n_value.shape != b_value.shape: # 判断NPU和bench的数据结构是否一致
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
return
|
|
78
|
+
err_msg = "Shape of NPU and bench tensor do not match. Skipped."
|
|
79
|
+
error_flag = True
|
|
80
|
+
return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, error_flag, err_msg
|
|
52
81
|
|
|
53
|
-
|
|
82
|
+
try:
|
|
83
|
+
n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
|
|
84
|
+
except CompareException:
|
|
85
|
+
logger.error('Numpy data is unreadable, please check!')
|
|
86
|
+
err_msg = "Data is unreadable."
|
|
87
|
+
error_flag = True
|
|
88
|
+
return CompareConst.UNREADABLE, CompareConst.UNREADABLE, error_flag, err_msg
|
|
54
89
|
if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
|
|
55
|
-
|
|
56
|
-
|
|
90
|
+
err_msg = "The position of inf or nan in NPU and bench Tensor do not match."
|
|
91
|
+
error_flag = True
|
|
92
|
+
return CompareConst.NAN, CompareConst.NAN, error_flag, err_msg
|
|
93
|
+
|
|
94
|
+
if n_value.dtype != b_value.dtype: # 判断数据的dtype是否一致
|
|
95
|
+
err_msg = "Dtype of NPU and bench tensor do not match."
|
|
96
|
+
error_flag = False
|
|
97
|
+
return n_value, b_value, error_flag, err_msg
|
|
98
|
+
|
|
99
|
+
return n_value, b_value, error_flag, err_msg
|
|
57
100
|
|
|
58
101
|
|
|
59
102
|
def reshape_value(n_value, b_value):
|
|
60
103
|
"""返回reshape后的数据"""
|
|
61
|
-
if not n_value.shape: #
|
|
104
|
+
if not n_value.shape: # 判断数据是否为0维tensor, 如果0维tensor,不会转成1维tensor,直接返回
|
|
62
105
|
if n_value.dtype == bool:
|
|
63
106
|
n_value = n_value.astype(float)
|
|
64
107
|
b_value = b_value.astype(float)
|
|
65
108
|
return n_value, b_value
|
|
66
109
|
|
|
67
|
-
n_value = n_value.reshape(-1).astype(float)
|
|
110
|
+
n_value = n_value.reshape(-1).astype(float) # 32转64为了防止某些数转dataframe时出现误差
|
|
68
111
|
b_value = b_value.reshape(-1).astype(float)
|
|
69
112
|
return n_value, b_value
|
|
70
113
|
|
|
71
114
|
|
|
72
|
-
def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
|
|
73
|
-
"""获取异常情况的错误信息"""
|
|
74
|
-
if error_flag:
|
|
75
|
-
if n_value == CompareConst.READ_NONE:
|
|
76
|
-
if error_file:
|
|
77
|
-
return "Dump file: {} not found.".format(error_file)
|
|
78
|
-
return CompareConst.NO_BENCH
|
|
79
|
-
if n_value == CompareConst.NONE:
|
|
80
|
-
return "This is empty data, can not compare."
|
|
81
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
82
|
-
return "Shape of NPU and bench Tensor do not match. Skipped."
|
|
83
|
-
if n_value == CompareConst.NAN:
|
|
84
|
-
return "The position of inf or nan in NPU and bench Tensor do not match."
|
|
85
|
-
else:
|
|
86
|
-
if not n_value.shape:
|
|
87
|
-
return "This is type of scalar data, can not compare."
|
|
88
|
-
if n_value.dtype != b_value.dtype:
|
|
89
|
-
logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
|
|
90
|
-
return "Dtype of NPU and bench Tensor do not match."
|
|
91
|
-
return ""
|
|
92
|
-
|
|
93
|
-
|
|
94
115
|
def npy_data_check(n_value, b_value):
|
|
95
116
|
error_message = ""
|
|
96
117
|
if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
|
|
@@ -109,7 +130,11 @@ def npy_data_check(n_value, b_value):
|
|
|
109
130
|
error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
|
|
110
131
|
|
|
111
132
|
if not error_message:
|
|
112
|
-
|
|
133
|
+
try:
|
|
134
|
+
n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
|
|
135
|
+
except CompareException:
|
|
136
|
+
logger.error('Numpy data is unreadable, please check!')
|
|
137
|
+
return True, 'Numpy data is unreadable, please check!'
|
|
113
138
|
# handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
|
|
114
139
|
if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
|
|
115
140
|
error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
|
|
@@ -144,10 +169,25 @@ def statistics_data_check(result_dict):
|
|
|
144
169
|
class TensorComparisonBasic(abc.ABC):
|
|
145
170
|
"""NPU和bench中npy数据的比较模板"""
|
|
146
171
|
@abc.abstractmethod
|
|
147
|
-
def apply(self, n_value, b_value,
|
|
172
|
+
def apply(self, n_value, b_value, relative_err):
|
|
148
173
|
raise NotImplementedError
|
|
149
174
|
|
|
150
175
|
|
|
176
|
+
def get_relative_err(n_value, b_value):
|
|
177
|
+
"""计算相对误差"""
|
|
178
|
+
with np.errstate(divide='ignore', invalid='ignore'):
|
|
179
|
+
if b_value.dtype not in CompareConst.FLOAT_TYPE:
|
|
180
|
+
n_value, b_value = n_value.astype(float), b_value.astype(float)
|
|
181
|
+
|
|
182
|
+
n_value_copy = n_value.copy()
|
|
183
|
+
b_value_copy = b_value.copy()
|
|
184
|
+
zero_mask = (b_value_copy == 0)
|
|
185
|
+
b_value_copy[zero_mask] += Const.FLOAT_EPSILON
|
|
186
|
+
n_value_copy[zero_mask] += Const.FLOAT_EPSILON
|
|
187
|
+
relative_err = np.divide((n_value_copy - b_value_copy), b_value_copy)
|
|
188
|
+
return np.abs(relative_err)
|
|
189
|
+
|
|
190
|
+
|
|
151
191
|
class GetCosineSimilarity(TensorComparisonBasic):
|
|
152
192
|
"""计算cosine相似度"""
|
|
153
193
|
@staticmethod
|
|
@@ -158,137 +198,67 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
158
198
|
return round(float(result), 6)
|
|
159
199
|
return result
|
|
160
200
|
|
|
161
|
-
def apply(self, n_value, b_value,
|
|
162
|
-
if error_flag:
|
|
163
|
-
if n_value == CompareConst.READ_NONE:
|
|
164
|
-
return CompareConst.NONE, ''
|
|
165
|
-
if n_value == CompareConst.NONE:
|
|
166
|
-
return CompareConst.UNSUPPORTED, ''
|
|
167
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
168
|
-
return CompareConst.SHAPE_UNMATCH, ''
|
|
169
|
-
if n_value == CompareConst.NAN:
|
|
170
|
-
return "N/A", ''
|
|
171
|
-
|
|
201
|
+
def apply(self, n_value, b_value, relative_err):
|
|
172
202
|
if not n_value.shape:
|
|
173
|
-
return CompareConst.UNSUPPORTED,
|
|
203
|
+
return CompareConst.UNSUPPORTED, ""
|
|
174
204
|
|
|
175
|
-
with np.errstate(divide=
|
|
205
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
176
206
|
if len(n_value) == 1:
|
|
177
|
-
return CompareConst.UNSUPPORTED, "This tensor
|
|
207
|
+
return CompareConst.UNSUPPORTED, "This is a 1-d tensor of length 1."
|
|
178
208
|
num = n_value.dot(b_value)
|
|
179
209
|
a_norm = np.linalg.norm(n_value)
|
|
180
210
|
b_norm = np.linalg.norm(b_value)
|
|
181
211
|
|
|
182
212
|
if a_norm <= Const.FLOAT_EPSILON and b_norm <= Const.FLOAT_EPSILON:
|
|
183
|
-
return 1.0,
|
|
213
|
+
return 1.0, ""
|
|
184
214
|
if a_norm <= Const.FLOAT_EPSILON:
|
|
185
|
-
return CompareConst.NAN,
|
|
215
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in npu dump data."
|
|
186
216
|
if b_norm <= Const.FLOAT_EPSILON:
|
|
187
|
-
return CompareConst.NAN,
|
|
217
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data."
|
|
188
218
|
|
|
189
219
|
cos = num / (a_norm * b_norm)
|
|
190
220
|
if np.isnan(cos):
|
|
191
|
-
return CompareConst.NAN,
|
|
221
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, the dump data has NaN."
|
|
192
222
|
result = format_value(cos)
|
|
193
223
|
result = self.correct_data(result)
|
|
194
|
-
return
|
|
224
|
+
return result, ""
|
|
195
225
|
|
|
196
226
|
|
|
197
227
|
class GetMaxAbsErr(TensorComparisonBasic):
|
|
198
228
|
"""计算最大绝对误差"""
|
|
199
|
-
def apply(self, n_value, b_value,
|
|
200
|
-
if error_flag:
|
|
201
|
-
if n_value == CompareConst.READ_NONE:
|
|
202
|
-
return CompareConst.NONE, ""
|
|
203
|
-
if n_value == CompareConst.NONE:
|
|
204
|
-
return 0, ""
|
|
205
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
206
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
207
|
-
if n_value == CompareConst.NAN:
|
|
208
|
-
return "N/A", ""
|
|
209
|
-
|
|
229
|
+
def apply(self, n_value, b_value, relative_err):
|
|
210
230
|
temp_res = n_value - b_value
|
|
211
231
|
max_value = np.max(np.abs(temp_res))
|
|
232
|
+
if np.isnan(max_value):
|
|
233
|
+
msg = "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data."
|
|
234
|
+
return CompareConst.NAN, msg
|
|
212
235
|
return format_value(max_value), ""
|
|
213
236
|
|
|
214
237
|
|
|
215
|
-
def get_relative_err(n_value, b_value):
|
|
216
|
-
"""计算相对误差"""
|
|
217
|
-
with np.errstate(divide='ignore', invalid='ignore'):
|
|
218
|
-
if b_value.dtype not in CompareConst.FLOAT_TYPE:
|
|
219
|
-
n_value, b_value = n_value.astype(float), b_value.astype(float)
|
|
220
|
-
zero_mask = (b_value == 0)
|
|
221
|
-
b_value[zero_mask] += np.finfo(b_value.dtype).eps
|
|
222
|
-
n_value[zero_mask] += np.finfo(b_value.dtype).eps
|
|
223
|
-
relative_err = np.divide((n_value - b_value), b_value)
|
|
224
|
-
return np.abs(relative_err)
|
|
225
|
-
|
|
226
|
-
|
|
227
238
|
class GetMaxRelativeErr(TensorComparisonBasic):
|
|
228
239
|
"""计算最大相对误差"""
|
|
229
|
-
def apply(self, n_value, b_value,
|
|
230
|
-
if error_flag:
|
|
231
|
-
if n_value == CompareConst.READ_NONE:
|
|
232
|
-
return CompareConst.NONE, ''
|
|
233
|
-
if n_value == CompareConst.NONE:
|
|
234
|
-
return 0, ''
|
|
235
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
236
|
-
return CompareConst.SHAPE_UNMATCH, ''
|
|
237
|
-
if n_value == CompareConst.NAN:
|
|
238
|
-
return "N/A", ''
|
|
239
|
-
|
|
240
|
-
if relative_err is None:
|
|
241
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
240
|
+
def apply(self, n_value, b_value, relative_err):
|
|
242
241
|
max_relative_err = np.max(np.abs(relative_err))
|
|
243
242
|
if np.isnan(max_relative_err):
|
|
244
|
-
|
|
245
|
-
return CompareConst.NAN,
|
|
246
|
-
return format_value(max_relative_err),
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
class GetThousandErrRatio(TensorComparisonBasic):
|
|
250
|
-
"""计算相对误差小于千分之一的比例"""
|
|
251
|
-
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
252
|
-
if error_flag:
|
|
253
|
-
if n_value == CompareConst.READ_NONE:
|
|
254
|
-
return CompareConst.NONE, ""
|
|
255
|
-
if n_value == CompareConst.NONE:
|
|
256
|
-
return 0, ""
|
|
257
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
258
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
259
|
-
if n_value == CompareConst.NAN:
|
|
260
|
-
return "N/A", ""
|
|
243
|
+
msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
|
|
244
|
+
return CompareConst.NAN, msg
|
|
245
|
+
return format_value(max_relative_err), ""
|
|
261
246
|
|
|
262
|
-
if not n_value.shape:
|
|
263
|
-
return CompareConst.NAN, ""
|
|
264
|
-
if relative_err is None:
|
|
265
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
266
|
-
if not np.size(relative_err):
|
|
267
|
-
return CompareConst.NAN, ""
|
|
268
|
-
return format_value(np.sum(relative_err < CompareConst.THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
class GetFiveThousandErrRatio(TensorComparisonBasic):
|
|
272
|
-
"""计算相对误差小于千分之五的比例"""
|
|
273
|
-
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
274
|
-
if error_flag:
|
|
275
|
-
if n_value == CompareConst.READ_NONE:
|
|
276
|
-
return CompareConst.NONE, ""
|
|
277
|
-
if n_value == CompareConst.NONE:
|
|
278
|
-
return 0, ""
|
|
279
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
280
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
281
|
-
if n_value == CompareConst.NAN:
|
|
282
|
-
return "N/A", ""
|
|
283
247
|
|
|
248
|
+
class GetErrRatio(TensorComparisonBasic):
|
|
249
|
+
"""计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
|
|
250
|
+
def __init__(self, threshold):
|
|
251
|
+
self.threshold = threshold
|
|
252
|
+
|
|
253
|
+
def apply(self, n_value, b_value, relative_err):
|
|
284
254
|
if not n_value.shape:
|
|
285
|
-
return CompareConst.
|
|
286
|
-
|
|
287
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
255
|
+
return CompareConst.UNSUPPORTED, ""
|
|
256
|
+
|
|
288
257
|
if not np.size(relative_err):
|
|
289
258
|
return CompareConst.NAN, ""
|
|
290
|
-
|
|
291
|
-
|
|
259
|
+
|
|
260
|
+
ratio = np.sum(relative_err < self.threshold) / np.size(relative_err)
|
|
261
|
+
return format_value(ratio), ""
|
|
292
262
|
|
|
293
263
|
|
|
294
264
|
class CompareOps:
|
|
@@ -296,15 +266,36 @@ class CompareOps:
|
|
|
296
266
|
"cosine_similarity": GetCosineSimilarity(),
|
|
297
267
|
"max_abs_error": GetMaxAbsErr(),
|
|
298
268
|
"max_relative_error": GetMaxRelativeErr(),
|
|
299
|
-
"one_thousand_err_ratio":
|
|
300
|
-
"five_thousand_err_ratio":
|
|
269
|
+
"one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
|
|
270
|
+
"five_thousand_err_ratio": GetErrRatio(CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
|
|
301
271
|
}
|
|
302
272
|
|
|
303
273
|
|
|
304
|
-
def
|
|
274
|
+
def error_value_process(n_value):
|
|
275
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
276
|
+
return CompareConst.UNSUPPORTED, ""
|
|
277
|
+
if n_value == CompareConst.NONE:
|
|
278
|
+
return 0, ""
|
|
279
|
+
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
280
|
+
return CompareConst.SHAPE_UNMATCH, ""
|
|
281
|
+
if n_value == CompareConst.NAN:
|
|
282
|
+
return CompareConst.N_A, ""
|
|
283
|
+
return CompareConst.N_A, ""
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def compare_ops_apply(n_value, b_value, error_flag, err_msg):
|
|
305
287
|
result_list = []
|
|
288
|
+
if error_flag:
|
|
289
|
+
result, msg = error_value_process(n_value)
|
|
290
|
+
result_list = [result] * len(CompareOps.compare_ops)
|
|
291
|
+
err_msg += msg * len(CompareOps.compare_ops)
|
|
292
|
+
return result_list, err_msg
|
|
293
|
+
|
|
294
|
+
relative_err = get_relative_err(n_value, b_value)
|
|
295
|
+
n_value, b_value = reshape_value(n_value, b_value)
|
|
296
|
+
|
|
306
297
|
for op in CompareOps.compare_ops.values():
|
|
307
|
-
result, msg = op.apply(n_value, b_value,
|
|
308
|
-
err_msg += msg
|
|
298
|
+
result, msg = op.apply(n_value, b_value, relative_err)
|
|
309
299
|
result_list.append(result)
|
|
300
|
+
err_msg += msg
|
|
310
301
|
return result_list, err_msg
|