mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from mindspore.common import dtype as mstype
|
|
2
|
+
import numpy as np
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# Canonical dtype names used as keys by every lookup table in this module.
INT8 = "Int8"
UINT8 = "UInt8"
INT16 = "Int16"
UINT16 = "UInt16"
INT32 = "Int32"
UINT32 = "UInt32"
INT64 = "Int64"
UINT64 = "UInt64"
FLOAT16 = "Float16"
FLOAT32 = "Float32"
FLOAT64 = "Float64"
BOOL = "Bool"
BFLOAT16 = "BFloat16"
INT4 = "Int4"


# dtype name -> MindSpore dtype. Note that INT4 is mapped to mstype.qint4x2.
dtype_str_to_ms_dtype = {
    INT8: mstype.int8,
    UINT8: mstype.uint8,
    INT16: mstype.int16,
    UINT16: mstype.uint16,
    INT32: mstype.int32,
    UINT32: mstype.uint32,
    INT64: mstype.int64,
    UINT64: mstype.uint64,
    FLOAT16: mstype.float16,
    FLOAT32: mstype.float32,
    FLOAT64: mstype.float64,
    BOOL: mstype.bool_,
    BFLOAT16: mstype.bfloat16,
    INT4: mstype.qint4x2,
}
# Inverse lookup: MindSpore dtype -> dtype name.
ms_dtype_to_dtype_str = {ms_dtype: name for name, ms_dtype in dtype_str_to_ms_dtype.items()}


# dtype name -> NumPy dtype. BFLOAT16 and INT4 have no entry in this table.
dtype_str_to_np_dtype = {
    INT8: np.int8,
    UINT8: np.uint8,
    INT16: np.int16,
    UINT16: np.uint16,
    INT32: np.int32,
    UINT32: np.uint32,
    INT64: np.int64,
    UINT64: np.uint64,
    FLOAT16: np.float16,
    FLOAT32: np.float32,
    FLOAT64: np.float64,
    BOOL: np.bool_,
}
# Inverse lookup: NumPy dtype -> dtype name.
np_dtype_to_dtype_str = {np_dtype: name for name, np_dtype in dtype_str_to_np_dtype.items()}

# dtype name -> torch dtype. UINT16, UINT32, UINT64 and INT4 have no entry in this table.
dtype_str_to_torch_dtype = {
    INT8: torch.int8,
    UINT8: torch.uint8,
    INT16: torch.int16,
    INT32: torch.int32,
    INT64: torch.int64,
    FLOAT16: torch.float16,
    FLOAT32: torch.float32,
    FLOAT64: torch.float64,
    BOOL: torch.bool,
    BFLOAT16: torch.bfloat16,
}
# Inverse lookup: torch dtype -> dtype name.
torch_dtype_to_dtype_str = {torch_dtype: name for name, torch_dtype in dtype_str_to_torch_dtype.items()}

# Type names as strings, and the Python / MindSpore types they stand for.
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
FLOAT_TYPE_STR = "float"
SLICE_TYPE_STR = "slice"
TUPLE_TYPE_STR = "tuple"
STR_TYPE_STR = "str"

# TUPLE_TYPE_STR is declared above but deliberately has no entry in this table as written.
api_info_type_str_to_type = {
    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
    BOOL_TYPE_STR: bool,
    INT_TYPE_STR: int,
    FLOAT_TYPE_STR: float,
    SLICE_TYPE_STR: slice,
    STR_TYPE_STR: str,
}
# Inverse lookup: type object -> type name.
type_to_api_info_type_str = {py_type: name for name, py_type in api_info_type_str_to_type.items()}

# NOTE(review): all three defaults are np.float64, including the INT and UINT ones;
# this looks deliberate but should be confirmed against the code that consumes them.
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64

# dtype names grouped by family. BOOL and INT4 are listed with the signed integers.
float_dtype_str_list = [FLOAT16, FLOAT32, FLOAT64, BFLOAT16]

int_dtype_str_list = [INT8, INT16, INT32, INT64, BOOL, INT4]

uint_dtype_str_list = [UINT8, UINT16, UINT32, UINT64]
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
2
|
+
from msprobe.core.common.log import logger
|
|
3
|
+
|
|
4
|
+
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    """
    Look up ``key`` in a dict parsed from input json and validate the value found.

    Args:
        dict_instance: dict, dict parsed from input json
        key: str, key to look up
        key_description: str, human-readable name of the key used in error messages
        accepted_type: tuple, types the value may have; None skips the type check
        accepted_value: Union[tuple, list], values that are allowed; None skips the value check

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        An ApiAccuracyCheckerException(ParseJsonFailed) is handed to
        logger.error_log_with_exp (which is expected to log and raise it) when
        1. dict_instance is not a dict
        2. value is None
        3. value is not accepted type
        4. value is not accepted value
    """
    parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
    if not isinstance(dict_instance, dict):
        logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", parse_failed_exception)

    value = dict_instance.get(key)

    # Only the first failing check is reported; the order is missing -> type -> value.
    failure_message = None
    if value is None:
        failure_message = f"check_and_get_from_json_dict failed: {key_description} is missing"
    elif accepted_type is not None and not isinstance(value, accepted_type):
        failure_message = \
            f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
    elif accepted_value is not None and value not in accepted_value:
        failure_message = \
            f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"

    if failure_message is not None:
        logger.error_log_with_exp(failure_message, parse_failed_exception)
    return value
|
|
39
|
+
|
|
40
|
+
def convert_to_tuple(input):
    """
    Return ``input`` as a tuple.

    A tuple or list is converted element by element; any other object
    (including a str or None) is wrapped in a one-element tuple.
    """
    return tuple(input) if isinstance(input, (tuple, list)) else (input,)
|
|
46
|
+
|
|
47
|
+
class GlobalContext:
    """Holder for two module-wide settings: ``is_constructed`` and ``dump_data_dir``."""

    def __init__(self):
        # Defaults used until init() is called.
        self.is_constructed, self.dump_data_dir = True, ""

    def init(self, is_constructed, dump_data_dir):
        """Overwrite both settings with the given values."""
        self.is_constructed, self.dump_data_dir = is_constructed, dump_data_dir

    def get_dump_data_dir(self):
        """Return the currently stored dump data directory."""
        return self.dump_data_dir

    def get_is_constructed(self):
        """Return the currently stored ``is_constructed`` flag."""
        return self.is_constructed


# Single shared instance created at import time.
global_context = GlobalContext()
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from msprobe.core.data_dump.scope import ModuleRangeScope
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.common.log import logger
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CellProcessor:
    """Builds begin/end hooks that name each cell call and report it to a range scope."""

    # Shared by all instances: name prefix -> index of the most recent call (starts at 0).
    cell_count = {}

    def __init__(self, scope):
        # Only a ModuleRangeScope is kept; any other scope object is dropped.
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None

    @staticmethod
    def set_cell_count(cell_name):
        """Advance the call counter for ``cell_name`` and return it (0 on first use)."""
        counts = CellProcessor.cell_count
        counts[cell_name] = counts[cell_name] + 1 if cell_name in counts else 0
        return counts[cell_name]

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """
        Return the begin hook when ``start_or_stop`` equals Const.START, else the end hook.

        Extra keyword arguments are accepted and ignored.
        """
        def begin_hook(cell, input):
            call_index = self.set_cell_count(name_prefix)
            full_name = name_prefix + Const.SEP + str(call_index)
            # Remember the name on the cell so the end hook can close the same module.
            cell.mindstudio_reserved_name = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(cell, input, output):
            if self.scope:
                self.scope.end_module(cell.mindstudio_reserved_name)

        if Const.START == start_or_stop:
            return begin_hook
        return end_hook
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import mindspore as ms
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.const import Const as CoreCost
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Const:
    """String constants for the MindSpore side: tool level names and execution mode names."""

    # Tool level names.
    CELL = "cell"
    API = "api"
    KERNEL = "kernel"
    # Core level identifier (L0 / L1 / L2) -> tool level name.
    TOOL_LEVEL_DICT = {CoreCost.LEVEL_L0: CELL, CoreCost.LEVEL_L1: API, CoreCost.LEVEL_L2: KERNEL}

    # Execution mode names.
    PYNATIVE_MODE = "pynative"
    GRAPH_GE_MODE = "graph_ge"
    GRAPH_KBYK_MODE = "graph_kbyk"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FreeBenchmarkConst:
    """Constants for the MindSpore free-benchmark feature: defaults, accepted options, lookup tables."""

    # Default option values.
    DEFAULT_DEVICE = "npu"
    DEFAULT_STAGE = "forward"
    DEFAULT_DUMP_LEVEL = CoreCost.LEVEL_L1
    DEFAULT_PERT_TYPE = "improve_precision"
    DEFAULT_HANDLER_TYPE = "check"
    FIX_HANDLER_MODE = "fix"

    # Names of the individual perturbation and handler types.
    ADD_NOISE = "add_noise"
    BIT_NOISE = "bit_noise"
    NO_CHANGE = "no_change"
    IMPROVE_PRECISION = "improve_precision"
    CHECK = "check"
    FIX = "fix"

    # Accepted values per option.
    DEVICE_LIST = ["npu"]
    STAGE_LIST = ["forward"]
    DUMP_LEVEL_LIST = [CoreCost.LEVEL_L1]
    PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE]
    HANDLER_TYPE_LIST = [CHECK, FIX]

    # Fully qualified names of communication APIs.
    COMMUNICATION_API_LIST = [
        "mindspore.communication.comm_func.all_gather_into_tensor",
        "mindspore.communication.comm_func.gather_into_tensor",
        "mindspore.communication.comm_func.all_reduce",
        "mindspore.communication.comm_func.reduce",
        "mindspore.communication.comm_func.reduce_scatter_tensor",
    ]

    # NOTE(review): how these two numbers are applied is decided by code outside this class.
    NO_CHANGE_ERROR_THRESHOLD = 1.0
    SYMBOL_FLIPPING_RATIO = 8.0

    # Dotted name prefixes of the API families.
    OPS_PREFIX = "mindspore.ops."
    Tensor_PREFIX = "mindspore.Tensor."
    MINT_PREFIX = "mindspore.mint."
    MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
    COMM_PREFIX = "mindspore.communication.comm_func."

    # Short API family key -> dotted name prefix.
    API_PREFIX_DICT = {
        "ops": OPS_PREFIX,
        "Tensor": Tensor_PREFIX,
        "mint": MINT_PREFIX,
        "mint.nn.functional": MINT_NN_FUNC_PREFIX,
        "communication": COMM_PREFIX,
    }

    # Float dtype -> small constant; presumably the perturbation magnitude, confirm at the use site.
    PERT_VALUE_DICT = {ms.bfloat16: 1e-4, ms.float16: 1e-6, ms.float32: 1e-8, ms.float64: 1e-16}

    # Float dtype -> threshold value; only float16 and float32 have entries.
    ERROR_THRESHOLD = {ms.float16: 1.002, ms.float32: 1.0002}

    # Float dtype -> NumPy integer type of the same bit width.
    PERT_BIT_DICT = {ms.float16: np.int16, ms.float32: np.int32, ms.float64: np.int64}

    # MindSpore dtype -> NumPy dtype of the same name.
    MS_NUMPY_DTYPE_DICT = {
        ms.int16: np.int16,
        ms.int32: np.int32,
        ms.int64: np.int64,
        ms.float16: np.float16,
        ms.float32: np.float32,
        ms.float64: np.float64,
    }
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
import sys
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
21
|
+
from msprobe.core.common.log import BaseLogger
|
|
22
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MindsporeLogger(BaseLogger):
    """BaseLogger subclass that takes the current rank from the MindSpore communication layer."""

    def __init__(self):
        super().__init__()

    def get_rank(self):
        """Return the current rank, or None when the distributed environment is not initialized."""
        try:
            return get_rank_if_initialized()
        except DistributedNotInitializedError:
            return None


# Module-level logger instance created at import time.
logger = MindsporeLogger()
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import mindspore as ms
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
19
|
+
from msprobe.core.common.file_check import path_len_exceeds_limit
|
|
20
|
+
from msprobe.core.common.utils import save_npy
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_rank_if_initialized():
    """
    Return ``ms.communication.get_rank()`` when ``GlobalComm.INITED`` is truthy.

    Raises:
        DistributedNotInitializedError: when ``GlobalComm.INITED`` is falsy.
    """
    if not ms.communication.GlobalComm.INITED:
        raise DistributedNotInitializedError("mindspore distributed environment is not initialized")
    return ms.communication.get_rank()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def convert_bf16_to_fp32(tensor):
    """Return ``tensor`` converted to float32 if its dtype is bfloat16, otherwise ``tensor`` itself."""
    return tensor.to(ms.float32) if tensor.dtype == ms.bfloat16 else tensor
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def save_tensor_as_npy(tensor, file_path):
    """
    Save ``tensor`` to ``file_path`` through ``save_npy``.

    When ``path_len_exceeds_limit(file_path)`` is true, nothing is written and a
    warning is logged instead. A bfloat16 tensor is converted to float32 before
    being turned into a NumPy array.
    """
    if path_len_exceeds_limit(file_path):
        logger.warning(f'The file path {file_path} length exceeds limit.')
        return
    saved_tensor = convert_bf16_to_fp32(tensor).asnumpy()
    save_npy(saved_tensor, file_path)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MsprobeStep(ms.train.Callback):
    """Training callback that drives a debugger object around every train step."""

    def __init__(self, debugger):
        super().__init__()
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        """Call ``debugger.start()`` when a train step begins."""
        self.debugger.start()

    def on_train_step_end(self, run_context):
        """Call ``debugger.stop()`` and then ``debugger.step()`` when a train step ends."""
        debugger = self.debugger
        debugger.stop()
        debugger.step()
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import os
|
|
18
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
19
|
+
check_configuration_param, task_dumppath_get
|
|
20
|
+
from msprobe.core.common.file_check import create_directory
|
|
21
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.mindspore.compare.ms_compare import MSComparator
|
|
24
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
|
+
from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
|
|
26
|
+
|
|
27
|
+
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare two multi-rank dump directories, pairing their ranks in sorted order.

    The rank entries of each directory are obtained through
    ``check_and_return_dir_contents(<dir>, 'rank')``, sorted, and matched
    pairwise. For every pair the dump/stack JSON paths are extracted, the
    parameters are validated, and ``MSComparator.compare_core`` is invoked
    with the suffix ``_<npu rank>-<bench rank>``.

    Args:
        npu_dump_dir: Directory holding the per-rank NPU dump folders.
        bench_dump_dir: Directory holding the per-rank benchmark dump folders.
        output_path: Directory passed to ``create_directory`` and used as the
            output location for every rank pair.
        **kwargs: Optional settings. ``stack_mode`` (default ``False``),
            ``auto_analyze`` (default ``True``) and ``fuzzy_match`` (default
            ``False``) are validated here; all keyword arguments except
            ``suffix`` are forwarded to ``compare_core``.

    Raises:
        CompareException: If a truthy ``suffix`` is supplied, if the two
            directories contain a different number of ranks, or if parameter
            validation fails for a rank pair.
    """
    if kwargs.get('suffix'):
        logger.error("Argument 'suffix' is not supported for compare_distributed.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    # A falsy 'suffix' (None, '') passes the guard above; forwarding it next to
    # the explicit ``suffix=`` below would raise TypeError for a duplicated
    # keyword argument, so it is dropped from the pass-through arguments.
    forwarded_kwargs = {key: value for key, value in kwargs.items() if key != 'suffix'}
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    # get the ranks and match by order
    npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
    bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
    if len(npu_ranks) != len(bench_ranks):
        logger.error('The number of ranks in the two runs are different. '
                     'Unable to match the ranks. Please use another folder to compare '
                     'or use compare() api and manually match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    for nr, br in zip(npu_ranks, bench_ranks):
        npu_data_dir = os.path.join(npu_dump_dir, nr)
        bench_data_dir = os.path.join(bench_dump_dir, br)
        npu_path = extract_json(npu_data_dir, stack_json=False)
        bench_path = extract_json(bench_data_dir, stack_json=False)
        # The stack file is taken from the NPU side only.
        stack_path = extract_json(npu_data_dir, stack_json=True)

        dump_result_param = {
            'npu_json_path': npu_path,
            'bench_json_path': bench_path,
            'stack_json_path': stack_path,
            'is_print_compare_log': True
        }
        try:
            summary_compare, md5_compare = task_dumppath_get(dump_result_param)
            check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
            create_directory(output_path)
            check_compare_param(dump_result_param, output_path, summary_compare=summary_compare,
                                md5_compare=md5_compare)
        except (CompareException, FileCheckException) as error:
            logger.error('Compare failed. Please check the arguments and do it again!')
            raise CompareException(error.code) from error
        ms_comparator = MSComparator()
        ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
                                   summary_compare=summary_compare, md5_compare=md5_compare,
                                   **forwarded_kwargs)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def ms_graph_compare(inputs, outputs):
    """Run a graph-mode comparison of ``inputs`` and write it under ``outputs``.

    The ``outputs`` directory is created first. If that fails with a
    ``CompareException`` or ``FileCheckException``, an error is logged and the
    function returns ``None`` without comparing anything. Otherwise a
    ``GraphMSComparator`` is built from both arguments and its
    ``compare_core()`` is executed.
    """
    try:
        create_directory(outputs)
    except (CompareException, FileCheckException):
        logger.error('Compare failed. Please check the arguments and do it again!')
        return
    GraphMSComparator(inputs, outputs).compare_core()
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
|
|
3
|
+
task_dumppath_get, load_yaml, load_npy
|
|
4
|
+
from msprobe.core.common.file_check import create_directory
|
|
5
|
+
from msprobe.core.common.const import Const
|
|
6
|
+
from msprobe.core.common.log import logger
|
|
7
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
8
|
+
from msprobe.core.compare.acc_compare import Comparator
|
|
9
|
+
from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MSComparator(Comparator):
    """Comparator for MindSpore dumps with optional cell/API name mapping.

    When ``cell_mapping`` or ``api_mapping`` is supplied the comparison is
    treated as cross-framework (``cross_frame`` is ``True``) and NPU op names
    are rewritten before being matched against the benchmark op names.
    """

    def __init__(self, cell_mapping=None, api_mapping=None):
        """Store the mapping options and load the mapping tables.

        Args:
            cell_mapping: Path (``str``) of a YAML cell-mapping file, any other
                non-``None`` value to enable cell mapping without a custom
                table, or ``None`` to disable it.
            api_mapping: Any non-``None`` value enables API mapping, which
                loads the bundled ``ms_to_pt_api.yaml`` table.
        """
        # NOTE(review): the base-class __init__ is not invoked here, as in the
        # original code; confirm Comparator needs no initialisation.
        self.frame_name = MSComparator.__name__
        self.cell_mapping = cell_mapping
        self.api_mapping = api_mapping
        self.cross_frame = cell_mapping is not None or api_mapping is not None
        self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
        self.api_mapping_dict = {}
        # Always define the table so process_api_mapping() cannot fail with
        # AttributeError when it is called without api_mapping being set.
        self.ms_to_pt_mapping = {}
        if api_mapping is not None:
            self.ms_to_pt_mapping = self.load_internal_api()

    def load_internal_api(self):
        """Load ``ms_to_pt_api.yaml`` located next to this module."""
        cur_path = os.path.dirname(os.path.realpath(__file__))
        yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
        return load_yaml(yaml_path)

    def load_mapping_file(self, mapping_file):
        """Return the YAML content of ``mapping_file`` if it is a ``str``, else ``{}``."""
        if isinstance(mapping_file, str):
            return load_yaml(mapping_file)
        return {}

    def process_cell_mapping(self, npu_op_name):
        """Rewrite NPU cell op names so they can be matched against module names.

        The first ``"Cell"`` in every name becomes ``"Module"``. If a cell
        mapping table is loaded, the cell name found in each op name is also
        replaced (first occurrence only) by its mapped value.

        Args:
            npu_op_name: List of NPU op-name strings.

        Returns:
            A new list with the rewritten names.
        """
        npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
        if self.cell_mapping_dict:
            for index, op_name in enumerate(npu_op_name):
                # get cell name & class name from op_name
                # Cell.fc1.Dense.forward.0.input.0 -> fc1.Dense
                cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
                if cell_name in self.cell_mapping_dict:
                    npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
        return npu_op_name

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Decide whether an NPU op and a benchmark op correspond to each other.

        Args:
            npu_dict: NPU op description; its ``"op_name"`` list is copied
                before any mapping is applied.
            bench_dict: Benchmark op description with an ``"op_name"`` list.
            fuzzy_match: When falsy the (mapped) name lists must be equal;
                otherwise ``fuzzy_check_op`` decides the name match.

        Returns:
            ``True`` only if the names match and ``check_struct_match`` accepts
            the pair.
        """
        npu_op_name = npu_dict["op_name"].copy()
        bench_op_name = bench_dict["op_name"].copy()

        if self.api_mapping is not None:
            npu_op_name = self.process_api_mapping(npu_op_name, bench_op_name)
        if self.cell_mapping is not None:
            npu_op_name = self.process_cell_mapping(npu_op_name)

        struct_match = check_struct_match(npu_dict, bench_dict, cross_frame=self.cross_frame)
        if not fuzzy_match:
            return npu_op_name == bench_op_name and struct_match
        try:
            is_match = fuzzy_check_op(npu_op_name, bench_op_name)
        except Exception:
            # Best effort: a failing fuzzy check counts as "no match".
            logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
            is_match = False
        return is_match and struct_match

    def read_npy_data(self, dir_path, file_name, load_pt_file=False):
        """Load one dumped tensor and return the loaded value.

        Args:
            dir_path: Directory containing the dump file.
            file_name: Name of the dump file inside ``dir_path``.
            load_pt_file: If true, the file is read with ``load_pt``, detached
                and converted with ``.numpy()``; otherwise it is read with
                ``load_npy``.
        """
        data_path = os.path.join(dir_path, file_name)
        if load_pt_file:
            # Imported lazily so torch is only required for this code path.
            import torch
            from msprobe.pytorch.common.utils import load_pt
            data_value = load_pt(data_path).detach()
            if data_value.dtype == torch.bfloat16:
                # Converted before .numpy(); presumably because NumPy has no
                # bfloat16 dtype.
                data_value = data_value.to(torch.float32)
            data_value = data_value.numpy()
        else:
            data_value = load_npy(data_path)
        return data_value

    def api_replace(self, npu_op_name, target, para):
        """Replace ``target`` with ``para`` in every name, in place, and return the list."""
        for idx, _ in enumerate(npu_op_name):
            npu_op_name[idx] = npu_op_name[idx].replace(target, para)
        return npu_op_name

    def process_api_mapping(self, npu_op_name, bench_op_name):
        """Rewrite MindSpore API op names to the matching PyTorch API names.

        ``Mint`` becomes ``Torch`` and ``MintFunctional`` becomes
        ``Functional``. Any other API name is replaced only when the bundled
        mapping table maps it to the benchmark API name.

        Args:
            npu_op_name: List of NPU op-name strings (modified in place when a
                replacement applies).
            bench_op_name: List of benchmark op-name strings.

        Returns:
            The (possibly rewritten) ``npu_op_name`` list.
        """
        # Nothing to derive an API name from: return unchanged rather than
        # raising IndexError on the [0] lookups below.
        if not npu_op_name or not bench_op_name:
            return npu_op_name
        # get api name & class name from op_name
        # Functional.addcmul.0.forward.input.0 -> Functional.addcmul
        ms_api_name = npu_op_name[0].rsplit(Const.SEP, 4)[0]
        pt_api_name = bench_op_name[0].rsplit(Const.SEP, 4)[0]
        class_name = ms_api_name.split(Const.SEP)[0]
        if class_name == "Mint":
            return self.api_replace(npu_op_name, "Mint", "Torch")
        elif class_name == "MintFunctional":
            return self.api_replace(npu_op_name, "MintFunctional", "Functional")
        elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
            return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
        else:
            return npu_op_name
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def ms_compare(input_param, output_path, **kwargs):
    """Validate the arguments and run a single MindSpore dump comparison.

    Args:
        input_param: Parameter object handed to ``task_dumppath_get``,
            ``check_compare_param`` and ``compare_core``.
        output_path: Directory created via ``create_directory`` and used as
            the comparison output location.
        **kwargs: Optional ``stack_mode`` (default ``False``),
            ``auto_analyze`` (default ``True``), ``fuzzy_match`` (default
            ``False``), ``cell_mapping`` and ``api_mapping`` (both default
            ``None``).

    Raises:
        CompareException: If any validation step raises ``CompareException``
            or ``FileCheckException``; the original error is chained.
    """
    # Plain dictionary lookups cannot raise the exceptions handled below, so
    # they are resolved before entering the guarded section.
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    cell_mapping = kwargs.get('cell_mapping')
    api_mapping = kwargs.get('api_mapping')
    try:
        summary_compare, md5_compare = task_dumppath_get(input_param)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        check_compare_param(input_param, output_path, summary_compare, md5_compare)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
    comparator = MSComparator(cell_mapping, api_mapping)
    comparator.compare_core(
        input_param,
        output_path,
        stack_mode=stack_mode,
        auto_analyze=auto_analyze,
        fuzzy_match=fuzzy_match,
        summary_compare=summary_compare,
        md5_compare=md5_compare,
    )
|