mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/core/compare/check.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,8 +16,7 @@
|
|
|
16
16
|
from msprobe.core.common.log import logger
|
|
17
17
|
from msprobe.core.compare.utils import rename_api
|
|
18
18
|
from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
|
|
19
|
-
from msprobe.core.common.const import Const
|
|
20
|
-
|
|
19
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
21
20
|
|
|
22
21
|
dtype_mapping = {
|
|
23
22
|
"Int8": "torch.int8",
|
|
@@ -35,37 +34,43 @@ dtype_mapping = {
|
|
|
35
34
|
"BFloat16": "torch.bfloat16",
|
|
36
35
|
"Complex64": "torch.complex64",
|
|
37
36
|
"Complex128": "torch.complex128"
|
|
38
|
-
|
|
37
|
+
}
|
|
38
|
+
|
|
39
39
|
|
|
40
|
+
def compare_op_dict_struct(npu_dict, bench_dict):
|
|
41
|
+
return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
|
|
40
42
|
|
|
41
|
-
def check_struct_match(npu_dict, bench_dict, cross_frame=False):
|
|
42
|
-
npu_struct_in = npu_dict.get("input_struct")
|
|
43
|
-
bench_struct_in = bench_dict.get("input_struct")
|
|
44
|
-
npu_struct_out = npu_dict.get("output_struct")
|
|
45
|
-
bench_struct_out = bench_dict.get("output_struct")
|
|
46
43
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
|
|
50
|
-
is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
|
|
44
|
+
def check_struct_match(npu_dict, bench_dict):
|
|
45
|
+
is_match = compare_op_dict_struct(npu_dict, bench_dict)
|
|
51
46
|
if not is_match:
|
|
52
|
-
|
|
53
|
-
return False
|
|
47
|
+
struct_match_list = []
|
|
54
48
|
try:
|
|
55
|
-
|
|
56
|
-
|
|
49
|
+
for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
|
|
50
|
+
# 首先额外检查input_struct是否空,input_struct不可能为空
|
|
51
|
+
if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
|
|
52
|
+
return False
|
|
53
|
+
struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
|
|
57
54
|
except CompareException as error:
|
|
58
55
|
err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
|
|
59
56
|
f'npu_dict: {npu_dict}' \
|
|
60
57
|
f'bench_dict: {bench_dict}'
|
|
61
58
|
logger.error(err_msg)
|
|
62
59
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
63
|
-
is_match =
|
|
60
|
+
is_match = all(struct_match_list)
|
|
64
61
|
return is_match
|
|
65
62
|
|
|
66
63
|
|
|
67
64
|
def check_type_shape_match(npu_struct, bench_struct):
|
|
68
|
-
|
|
65
|
+
"""
|
|
66
|
+
further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
|
|
67
|
+
"""
|
|
68
|
+
if len(npu_struct) != len(bench_struct):
|
|
69
|
+
return False
|
|
70
|
+
if not npu_struct and not bench_struct:
|
|
71
|
+
return True
|
|
72
|
+
|
|
73
|
+
struct_match = False
|
|
69
74
|
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
70
75
|
try:
|
|
71
76
|
npu_type = npu_type_shape[0]
|
|
@@ -79,22 +84,14 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
79
84
|
shape_match = npu_shape == bench_shape
|
|
80
85
|
type_match = npu_type == bench_type
|
|
81
86
|
if not type_match:
|
|
82
|
-
|
|
83
|
-
[Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
|
|
84
|
-
[Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
|
|
85
|
-
]
|
|
86
|
-
torch_type = [
|
|
87
|
-
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
88
|
-
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
89
|
-
]
|
|
90
|
-
if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
|
|
87
|
+
if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
|
|
91
88
|
type_match = True
|
|
92
89
|
else:
|
|
93
90
|
type_match = False
|
|
94
|
-
|
|
95
|
-
if not
|
|
91
|
+
struct_match = shape_match and type_match
|
|
92
|
+
if not struct_match:
|
|
96
93
|
return False
|
|
97
|
-
return
|
|
94
|
+
return struct_match
|
|
98
95
|
|
|
99
96
|
|
|
100
97
|
def check_graph_mode(a_op_name, b_op_name):
|
|
@@ -106,6 +103,8 @@ def check_graph_mode(a_op_name, b_op_name):
|
|
|
106
103
|
|
|
107
104
|
|
|
108
105
|
def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
106
|
+
# 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
|
|
107
|
+
# 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
|
|
109
108
|
if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
|
|
110
109
|
return False
|
|
111
110
|
is_match = True
|
|
@@ -151,11 +150,11 @@ def check_json_key_value(input_output, op_name, depth=0):
|
|
|
151
150
|
return
|
|
152
151
|
if isinstance(input_output, list):
|
|
153
152
|
for item in input_output:
|
|
154
|
-
check_json_key_value(item, op_name, depth+1)
|
|
153
|
+
check_json_key_value(item, op_name, depth + 1)
|
|
155
154
|
elif isinstance(input_output, dict):
|
|
156
155
|
for key, value in input_output.items():
|
|
157
156
|
if isinstance(value, dict):
|
|
158
|
-
check_json_key_value(value, op_name, depth+1)
|
|
157
|
+
check_json_key_value(value, op_name, depth + 1)
|
|
159
158
|
else:
|
|
160
159
|
valid_key_value(key, value, op_name)
|
|
161
160
|
|
|
@@ -14,17 +14,22 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
from msprobe.core.common.file_utils import
|
|
17
|
+
from msprobe.core.common.file_utils import check_file_type, load_json
|
|
18
18
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
19
19
|
from msprobe.core.common.utils import CompareException
|
|
20
20
|
from msprobe.core.common.log import logger
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def compare_cli(args):
|
|
24
|
-
|
|
25
|
-
input_param = json.load(file)
|
|
24
|
+
input_param = load_json(args.input_path)
|
|
26
25
|
npu_path = input_param.get("npu_path", None)
|
|
27
26
|
bench_path = input_param.get("bench_path", None)
|
|
27
|
+
if not npu_path:
|
|
28
|
+
logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
|
|
29
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
30
|
+
if not bench_path:
|
|
31
|
+
logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
|
|
32
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
28
33
|
frame_name = args.framework
|
|
29
34
|
auto_analyze = not args.compare_only
|
|
30
35
|
if frame_name == Const.PT_FRAMEWORK:
|
|
@@ -33,30 +38,43 @@ def compare_cli(args):
|
|
|
33
38
|
else:
|
|
34
39
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
35
40
|
from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
|
|
41
|
+
|
|
42
|
+
common_kwargs = {
|
|
43
|
+
"auto_analyze": auto_analyze,
|
|
44
|
+
"fuzzy_match": args.fuzzy_match,
|
|
45
|
+
"data_mapping": args.data_mapping,
|
|
46
|
+
}
|
|
47
|
+
|
|
36
48
|
if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
|
|
37
49
|
input_param["npu_json_path"] = input_param.pop("npu_path")
|
|
38
50
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
39
|
-
|
|
51
|
+
if "stack_path" not in input_param:
|
|
52
|
+
logger.warning(f"Missing stack_path in the configuration file. "
|
|
53
|
+
f"Automatically detecting stack.json to determine whether to display NPU_Stack_Info.")
|
|
54
|
+
else:
|
|
55
|
+
input_param["stack_json_path"] = input_param.pop("stack_path")
|
|
56
|
+
|
|
40
57
|
if frame_name == Const.PT_FRAMEWORK:
|
|
41
|
-
kwargs = {
|
|
42
|
-
|
|
43
|
-
}
|
|
44
|
-
compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
|
|
45
|
-
fuzzy_match=args.fuzzy_match, **kwargs)
|
|
58
|
+
kwargs = {**common_kwargs, "stack_mode": args.stack_mode}
|
|
59
|
+
compare(input_param, args.output_path, **kwargs)
|
|
46
60
|
else:
|
|
47
61
|
kwargs = {
|
|
62
|
+
**common_kwargs,
|
|
48
63
|
"stack_mode": args.stack_mode,
|
|
49
|
-
"auto_analyze": auto_analyze,
|
|
50
|
-
"fuzzy_match": args.fuzzy_match,
|
|
51
64
|
"cell_mapping": args.cell_mapping,
|
|
52
65
|
"api_mapping": args.api_mapping,
|
|
53
|
-
"data_mapping": args.data_mapping,
|
|
54
66
|
"layer_mapping": args.layer_mapping
|
|
55
67
|
}
|
|
56
|
-
|
|
57
68
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
58
69
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
59
|
-
kwargs = {
|
|
70
|
+
kwargs = {
|
|
71
|
+
**common_kwargs,
|
|
72
|
+
"stack_mode": args.stack_mode,
|
|
73
|
+
"is_print_compare_log": input_param.get("is_print_compare_log", True),
|
|
74
|
+
"cell_mapping": args.cell_mapping,
|
|
75
|
+
"api_mapping": args.api_mapping,
|
|
76
|
+
"layer_mapping": args.layer_mapping
|
|
77
|
+
}
|
|
60
78
|
if input_param.get("rank_id") is not None:
|
|
61
79
|
ms_graph_compare(input_param, args.output_path)
|
|
62
80
|
return
|