mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/compare/check.py
CHANGED
|
@@ -14,117 +14,46 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.core.common.log import logger
|
|
17
|
-
from msprobe.core.compare.utils import rename_api
|
|
18
17
|
from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
|
|
19
|
-
from msprobe.core.common.const import
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
"Int8": "
|
|
23
|
-
"
|
|
24
|
-
"
|
|
25
|
-
"
|
|
26
|
-
"
|
|
27
|
-
"
|
|
28
|
-
"
|
|
29
|
-
"
|
|
30
|
-
"
|
|
31
|
-
"
|
|
32
|
-
"
|
|
33
|
-
"
|
|
34
|
-
"
|
|
35
|
-
"
|
|
36
|
-
"
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
|
|
20
|
+
cross_dtype_mapping = {
|
|
21
|
+
"Int8": "int",
|
|
22
|
+
"torch.int8": "int",
|
|
23
|
+
"UInt8": "int",
|
|
24
|
+
"torch.uint8": "int",
|
|
25
|
+
"Int16": "int",
|
|
26
|
+
"torch.int16": "int",
|
|
27
|
+
"UInt16": "int",
|
|
28
|
+
"torch.uint16": "int",
|
|
29
|
+
"Int32": "int",
|
|
30
|
+
"torch.int32": "int",
|
|
31
|
+
"UInt32": "int",
|
|
32
|
+
"torch.uint32": "int",
|
|
33
|
+
"Int64": "int",
|
|
34
|
+
"torch.int64": "int",
|
|
35
|
+
"UInt64": "int",
|
|
36
|
+
"torch.uint64": "int",
|
|
37
|
+
|
|
38
|
+
"Float16": "float",
|
|
39
|
+
"torch.float16": "float",
|
|
40
|
+
"Float32": "float",
|
|
41
|
+
"torch.float32": "float",
|
|
42
|
+
"Float64": "float",
|
|
43
|
+
"torch.float64": "float",
|
|
44
|
+
"BFloat16": "float",
|
|
45
|
+
"torch.bfloat16": "float",
|
|
46
|
+
|
|
47
|
+
"Bool": "bool",
|
|
48
|
+
"torch.bool": "bool",
|
|
49
|
+
|
|
50
|
+
"Complex64": "complex",
|
|
51
|
+
"torch.complex64": "complex",
|
|
52
|
+
"Complex128": "complex",
|
|
53
|
+
"torch.complex128": "complex",
|
|
37
54
|
}
|
|
38
55
|
|
|
39
56
|
|
|
40
|
-
def compare_op_dict_struct(npu_dict, bench_dict):
|
|
41
|
-
return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def check_struct_match(npu_dict, bench_dict):
|
|
45
|
-
is_match = compare_op_dict_struct(npu_dict, bench_dict)
|
|
46
|
-
if not is_match:
|
|
47
|
-
struct_match_list = []
|
|
48
|
-
try:
|
|
49
|
-
for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
|
|
50
|
-
# 首先额外检查input_struct是否空,input_struct不可能为空
|
|
51
|
-
if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
|
|
52
|
-
return False
|
|
53
|
-
struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
|
|
54
|
-
except CompareException as error:
|
|
55
|
-
err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
|
|
56
|
-
f'npu_dict: {npu_dict}' \
|
|
57
|
-
f'bench_dict: {bench_dict}'
|
|
58
|
-
logger.error(err_msg)
|
|
59
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
60
|
-
is_match = all(struct_match_list)
|
|
61
|
-
return is_match
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def check_type_shape_match(npu_struct, bench_struct):
|
|
65
|
-
"""
|
|
66
|
-
further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
|
|
67
|
-
"""
|
|
68
|
-
if len(npu_struct) != len(bench_struct):
|
|
69
|
-
return False
|
|
70
|
-
if not npu_struct and not bench_struct:
|
|
71
|
-
return True
|
|
72
|
-
|
|
73
|
-
struct_match = False
|
|
74
|
-
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
75
|
-
try:
|
|
76
|
-
npu_type = npu_type_shape[0]
|
|
77
|
-
npu_shape = npu_type_shape[1]
|
|
78
|
-
bench_type = bench_type_shape[0]
|
|
79
|
-
bench_shape = bench_type_shape[1]
|
|
80
|
-
except IndexError as error:
|
|
81
|
-
logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
|
|
82
|
-
f'should both be 2, please check!')
|
|
83
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
84
|
-
shape_match = npu_shape == bench_shape
|
|
85
|
-
type_match = npu_type == bench_type
|
|
86
|
-
if not type_match:
|
|
87
|
-
if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
|
|
88
|
-
type_match = True
|
|
89
|
-
else:
|
|
90
|
-
type_match = False
|
|
91
|
-
struct_match = shape_match and type_match
|
|
92
|
-
if not struct_match:
|
|
93
|
-
return False
|
|
94
|
-
return struct_match
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def check_graph_mode(a_op_name, b_op_name):
|
|
98
|
-
if Const.ATEN in a_op_name and Const.ATEN not in b_op_name:
|
|
99
|
-
return True
|
|
100
|
-
if Const.ATEN not in a_op_name and Const.ATEN in b_op_name:
|
|
101
|
-
return True
|
|
102
|
-
return False
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
106
|
-
# 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
|
|
107
|
-
# 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
|
|
108
|
-
if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
|
|
109
|
-
return False
|
|
110
|
-
is_match = True
|
|
111
|
-
for npu_name, bench_name in zip(npu_name_list, bench_name_list):
|
|
112
|
-
is_match = fuzzy_check_name(npu_name, bench_name)
|
|
113
|
-
if not is_match:
|
|
114
|
-
break
|
|
115
|
-
return is_match
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def fuzzy_check_name(npu_name, bench_name):
|
|
119
|
-
if Const.FORWARD in npu_name and Const.FORWARD in bench_name:
|
|
120
|
-
is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD)
|
|
121
|
-
elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name:
|
|
122
|
-
is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD)
|
|
123
|
-
else:
|
|
124
|
-
is_match = npu_name == bench_name
|
|
125
|
-
return is_match
|
|
126
|
-
|
|
127
|
-
|
|
128
57
|
def check_dump_json_str(op_data, op_name):
|
|
129
58
|
input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get(
|
|
130
59
|
Const.INPUT, None)
|
|
@@ -38,6 +38,7 @@ def compare_cli(args):
|
|
|
38
38
|
else:
|
|
39
39
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
40
40
|
from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
|
|
41
|
+
from msprobe.mindspore.compare.common_dir_compare import common_dir_compare
|
|
41
42
|
|
|
42
43
|
common_kwargs = {
|
|
43
44
|
"auto_analyze": auto_analyze,
|
|
@@ -78,6 +79,9 @@ def compare_cli(args):
|
|
|
78
79
|
if input_param.get("rank_id") is not None:
|
|
79
80
|
ms_graph_compare(input_param, args.output_path)
|
|
80
81
|
return
|
|
82
|
+
if input_param.get('common', False):
|
|
83
|
+
common_dir_compare(input_param, args.output_path)
|
|
84
|
+
return
|
|
81
85
|
if frame_name == Const.PT_FRAMEWORK:
|
|
82
86
|
compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
|
|
83
87
|
else:
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
19
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ModeConfig:
|
|
23
|
+
def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY,
|
|
24
|
+
compared_file_type=Const.DUMP_JSON_FILE):
|
|
25
|
+
self.stack_mode = stack_mode
|
|
26
|
+
self.auto_analyze = auto_analyze
|
|
27
|
+
self.fuzzy_match = fuzzy_match
|
|
28
|
+
self.dump_mode = dump_mode
|
|
29
|
+
self.compared_file_type = compared_file_type
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MappingConfig:
|
|
33
|
+
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
|
|
34
|
+
self.cell_mapping = cell_mapping
|
|
35
|
+
self.api_mapping = api_mapping
|
|
36
|
+
self.data_mapping = data_mapping
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MappingDict:
|
|
40
|
+
def __init__(self, mapping_config: MappingConfig):
|
|
41
|
+
self.cell_mapping_dict = self.load_mapping_file(mapping_config.cell_mapping)
|
|
42
|
+
self.api_mapping_dict = self.load_mapping_file(mapping_config.api_mapping)
|
|
43
|
+
if mapping_config.api_mapping is not None:
|
|
44
|
+
self.ms_to_pt_mapping = self.load_internal_api()
|
|
45
|
+
self.data_mapping_dict = self.init_data_mapping(mapping_config.data_mapping)
|
|
46
|
+
|
|
47
|
+
@staticmethod
|
|
48
|
+
def load_internal_api():
|
|
49
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
50
|
+
yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
|
|
51
|
+
return load_yaml(yaml_path)
|
|
52
|
+
|
|
53
|
+
@staticmethod
|
|
54
|
+
def load_mapping_file(mapping_file):
|
|
55
|
+
if isinstance(mapping_file, str):
|
|
56
|
+
mapping_dict = load_yaml(mapping_file)
|
|
57
|
+
else:
|
|
58
|
+
mapping_dict = {}
|
|
59
|
+
return mapping_dict
|
|
60
|
+
|
|
61
|
+
def init_data_mapping(self, data_mapping):
|
|
62
|
+
"""
|
|
63
|
+
初始化data_mapping_dict
|
|
64
|
+
"""
|
|
65
|
+
if isinstance(data_mapping, str) or data_mapping is None:
|
|
66
|
+
data_mapping_dict = self.load_mapping_file(data_mapping)
|
|
67
|
+
elif isinstance(data_mapping, dict):
|
|
68
|
+
data_mapping_dict = data_mapping
|
|
69
|
+
else:
|
|
70
|
+
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
71
|
+
f"{type(data_mapping)}")
|
|
72
|
+
return data_mapping_dict
|