mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
from mindspore import ops
|
|
6
|
+
|
|
7
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
|
+
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
+
from msprobe.core.common.log import logger
|
|
11
|
+
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ApiInputAggregation:
    """Bundle of everything one api invocation consumes: positional args, keyword args and gradients."""

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        '''
        Args:
            inputs: List[ComputeElement], positional arguments of the api
            kwargs: dict{str: ComputeElement}, keyword arguments of the api
            gradient_inputs: Union[List[ComputeElement], None], gradients fed into the backward pass;
                None when only the forward pass is run
        '''
        self.inputs, self.kwargs, self.gradient_inputs = inputs, kwargs, gradient_inputs
|
|
25
|
+
|
|
26
|
+
# (api type, framework) -> module whose attributes are the api functions looked up by name.
# mindspore.mint.X pairs with torch.X, mindspore.mint.nn.functional.X pairs with torch.nn.functional.X.
api_parent_module_mapping = {}
api_parent_module_mapping[(MsCompareConst.MINT, Const.MS_FRAMEWORK)] = mindspore.mint
api_parent_module_mapping[(MsCompareConst.MINT, Const.PT_FRAMEWORK)] = torch
api_parent_module_mapping[(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK)] = mindspore.mint.nn.functional
api_parent_module_mapping[(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK)] = torch.nn.functional
|
|
32
|
+
|
|
33
|
+
class ApiRunner:
    """Look up a mindspore.mint / torch api by its dump name and run it forward or backward."""

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"

        Errors:
            reported through logger.error_log_with_exp with an ApiAccuracyCheckerException when the name
            does not split into exactly three Const.SEP-separated parts or its type is neither Mint nor
            MintFunctional (the code after each check assumes that call raises).
        '''
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str, Union["mindspore", "torch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        '''
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        # human readable dotted name, only used in error messages
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name
        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: callable returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            list[ComputeElement]: the forward outputs, or when running backward the gradients w.r.t. the
            positional inputs (mindspore: all inputs via GradOperation; torch: floating point / complex
            tensor inputs, the only ones that can require grad)
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}

        if forward_or_backward == Const.FORWARD:
            # the api can return a single tensor or a tuple
            results = convert_to_tuple(api_instance(*inputs, **kwargs))
        else:
            gradient_inputs = api_input_aggregation.gradient_inputs
            if gradient_inputs is None:
                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK:
                results = ApiRunner._run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs)
            else:
                results = ApiRunner._run_torch_backward(api_instance, inputs, kwargs, gradient_inputs)

        return [ComputeElement(parameter=api_res) for api_res in results]

    @staticmethod
    def _run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Return the gradients of a mindspore api w.r.t. all positional inputs, as a tuple."""
        # a single gradient is passed as one sens value rather than as a 1-tuple
        sens = gradient_inputs[0] if len(gradient_inputs) == 1 else gradient_inputs

        def api_with_kwargs(*forward_inputs):
            return api_instance(*forward_inputs, **kwargs)

        grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
        # the result can be a single tensor or a tuple
        return convert_to_tuple(grad_func(*inputs, sens))

    @staticmethod
    def _run_torch_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Backpropagate gradient_inputs through a torch api; return the .grad of each differentiable input."""
        # only floating point / complex tensors can require grad; setting it on an integer tensor raises
        differentiable_inputs = [tensor for tensor in inputs
                                 if isinstance(tensor, torch.Tensor)
                                 and (tensor.is_floating_point() or tensor.is_complex())]
        for tensor in differentiable_inputs:
            tensor.requires_grad_(True)

        forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
        # outputs without a grad_fn (e.g. integer index outputs) cannot be backpropagated through
        outputs, output_grads = [], []
        for forward_res, gradient_in in zip(forward_results, gradient_inputs):
            if isinstance(forward_res, torch.Tensor) and forward_res.requires_grad:
                outputs.append(forward_res)
                output_grads.append(gradient_in)
        # one autograd pass over all outputs: calling .backward() once per output frees the graph after the
        # first call, so a later call fails whenever the outputs share saved tensors
        if outputs:
            torch.autograd.backward(outputs, grad_tensors=output_grads)
        return [tensor.grad for tensor in differentiable_inputs]


api_runner = ApiRunner()
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
8
|
+
from msprobe.core.common.log import logger
|
|
9
|
+
from msprobe.core.common.const import CompareConst, MsCompareConst
|
|
10
|
+
|
|
11
|
+
class CompareResult:
    """Outcome of one comparison algorithm: the computed value, its pass status and a message."""

    def __init__(self, compare_value, pass_status, err_msg):
        # compare_value is None when the comparison was skipped
        self.compare_value, self.pass_status, self.err_msg = compare_value, pass_status, err_msg
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseCompareAlgorithm(ABC):
    """Template for one comparison metric between a bench and a tested ComputeElement.

    Subclasses set ``compare_algorithm_name`` and implement check_validity / run_compare / check_pass;
    ``__call__`` chains them into a CompareResult.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = None
        # algorithm name -> pass status -> message attached to the CompareResult
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than "
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max relative difference is greater than "
                                    f"threshold: {CompareConst.MAX_RELATIVE_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max relative difference, skip comparing ",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult; compare_value is None and pass_status is SKIP when the inputs
            fail check_validity
        '''
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        '''
        Args:
            tensor: mindspore.Tensor or torch.Tensor

        Return:
            numpy.ndarray of dtype float64 holding the tensor's values

        Errors:
            any other input type is reported through logger.error_log_with_exp (UnsupportType)
        '''
        if isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        elif isinstance(tensor, torch.Tensor):
            # detach + move to cpu first: Tensor.numpy() refuses tensors that require grad or live on a
            # non-cpu device; copy=True keeps the ndarray from aliasing the caller's tensor
            ndarray = tensor.detach().to(device="cpu", dtype=torch.float64, copy=True).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        """Return True when both parameters are mindspore/torch tensors and their shapes are equal."""
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        '''
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        '''
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        '''
        Args:
            compare_value: float/int

        Return:
            pass_status: str
        '''
        raise NotImplementedError
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    """Cosine similarity of the flattened bench and tested tensors; passes above CompareConst.COS_THRESHOLD."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Return:
            cosine_similarity: numpy float64
                - 1.0 when both tensors are all zeros (identical),
                - nan when exactly one of them is all zeros (the similarity is undefined, so check_pass
                  reports ERROR instead of a spurious PASS),
                - dot(bench, tested) / (|bench| * |tested|) otherwise.

        Note: no epsilon smoothing is applied. Adding an epsilon to numerator and denominator would push the
        result toward 1.0 whenever the norms are small, e.g. letting an all-zero tensor or two tiny
        orthogonal tensors pass.
        '''
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()).flatten()
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()).flatten()

        bench_norm = np.linalg.norm(bench_ndarray)
        tested_norm = np.linalg.norm(tested_ndarray)
        bench_is_zero = bench_norm == 0
        tested_is_zero = tested_norm == 0
        if bench_is_zero and tested_is_zero:
            return np.float64(1.0)
        if bench_is_zero or tested_is_zero:
            return np.float64(np.nan)

        dot_product = np.dot(bench_ndarray, tested_ndarray)
        cosine_similarity = dot_product / (bench_norm * tested_norm)
        return cosine_similarity

    def check_pass(self, compare_value):
        # nan compares False, so an undefined similarity is reported as ERROR
        if compare_value > CompareConst.COS_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise |bench - tested|; passes below CompareConst.MAX_ABS_ERR_THRESHOLD."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Return:
            max_absolute_diff: numpy float64; 0.0 for empty tensors, nan if any difference is nan
        '''
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        # initial=0.0: empty tensors of equal shape pass check_validity, and a plain np.max would raise
        # ValueError on them; absolute differences are >= 0, so the initial value never changes other results
        max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray), initial=0.0)
        return max_absolute_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise |bench - tested| / |bench|; passes below CompareConst.MAX_RELATIVE_ERR_THRESHOLD."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Return:
            max_relative_diff: numpy float64; 0.0 for empty tensors, nan if any ratio is nan
        '''
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        abs_diff = np.abs(bench_ndarray - tested_ndarray)
        # replace exact zeros in the denominator by EPSILON to avoid division by zero
        bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON
        # initial=0.0: empty tensors of equal shape pass check_validity, and a plain np.max would raise
        # ValueError on them; the ratios are >= 0, so the initial value never changes other results
        max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero, initial=0.0)
        return max_relative_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# algorithm name -> ready-to-call algorithm instance; each instance sets its own compare_algorithm_name
# in __init__, so that attribute doubles as the registry key
compare_algorithms = {
    algorithm.compare_algorithm_name: algorithm
    for algorithm in (
        CosineSimilarityCompareAlgorithm(),
        MaxAbsoluteDiffCompareAlgorithm(),
        MaxRelativeDiffCompareAlgorithm(),
    )
}
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from msprobe.core.common.log import logger
|
|
8
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
9
|
+
from msprobe.core.common.utils import load_npy
|
|
10
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
|
|
11
|
+
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
12
|
+
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
13
|
+
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
14
|
+
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
15
|
+
MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
|
|
16
|
+
int_dtype_str_list)
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MstensorMetaData:
    """Plain record describing a MindSpore tensor without holding its data.

    All fields are stored exactly as given:
        dtype_str: dtype name string.
        npy_path: path of the dumped .npy file, or None when data is constructed.
        maximum / minimum: value range used for construction, or None when loaded from file.
        shape: tensor shape as supplied by the caller.
    """

    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        self.dtype_str, self.npy_path = dtype_str, npy_path
        self.maximum, self.minimum = maximum, minimum
        self.shape = shape
|
|
28
|
+
|
|
29
|
+
class ComputeElement:
    """Wrap one API argument (tensor, scalar, slice or nested tuple) for accuracy checking.

    An instance is built either from a live ``parameter`` or from a
    ``compute_element_info`` entry (list or dict) read from api_info.json.
    MindSpore tensors described in json are stored as ``MstensorMetaData`` and
    only materialised when ``get_parameter`` is called.
    """

    def __init__(self, compute_element_info=None, parameter=None):
        # Types accepted verbatim: every type known to the api_info mapping, plus torch tensors and tuples.
        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
        if parameter is not None:
            self._init_with_parameter(parameter)
        elif isinstance(compute_element_info, (list, dict)):
            self._init_from_compute_element_info(compute_element_info)
        else:
            logger.error_log_with_exp(
                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    @staticmethod
    def transfer_to_torch_tensor(ms_tensor):
        '''
        Convert a MindSpore tensor to a torch tensor of the matching dtype.

        The data goes through numpy via a wide intermediate dtype (float64,
        int64, or uint64 for the remaining dtypes) and is then cast to the
        target torch dtype.

        Args:
            ms_tensor: mindspore.Tensor
        Return:
            torch_tensor: torch.Tensor
        '''
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_torch_dtype:
            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)

        if dtype_str in float_dtype_str_list:
            middle_dtype = mindspore.float64
        elif dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            # NOTE(review): older torch releases reject uint64 numpy arrays in
            # torch.from_numpy; confirm this path works for the dtypes reaching it.
            middle_dtype = mindspore.uint64
        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
        return torch_tensor

    @staticmethod
    def transfer_to_mindspore_tensor(torch_tensor):
        '''
        Convert a torch tensor to a MindSpore tensor of the matching dtype.

        Float and int data go through numpy via float64 / int64. Any other
        dtype (e.g. bool) is exported to numpy in its own dtype.

        Args:
            torch_tensor: torch.Tensor

        Return:
            ms_tensor: mindspore.Tensor
        '''
        torch_dtype = torch_tensor.dtype
        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
        if dtype_str not in dtype_str_to_ms_dtype:
            err_msg = \
                f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)

        if dtype_str in float_dtype_str_list:
            middle_dtype = torch.float64
        elif dtype_str in int_dtype_str_list:
            middle_dtype = torch.int64
        else:
            # Previously unassigned here (UnboundLocalError for e.g. bool); keep the native dtype.
            middle_dtype = torch_dtype
        # detach()/cpu() so tensors that require grad or live off-CPU can be exported to numpy.
        np_ndarray = torch_tensor.detach().cpu().to(middle_dtype, copy=True).numpy()
        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
        return ms_tensor

    @staticmethod
    def convert_inf_to_real_num(value, dtype_str):
        '''
        Replace +inf / -inf with the largest / smallest finite value of a float dtype.

        Args:
            value: number to check
            dtype_str: dtype name looked up in dtype_str_to_np_dtype; unknown keys
                fall back to DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE.

        Return:
            value unchanged when finite, otherwise np.finfo(np_dtype).max / .min
        '''
        if value == float("inf"):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            value = np.finfo(np_dtype).max
        elif value == float("-inf"):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            value = np.finfo(np_dtype).min
        return value

    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
        '''
        Return the stored parameter, materialising MstensorMetaData if needed.

        A MstensorMetaData is turned into a mindspore.Tensor. Its data is either
        randomly constructed from shape/Max/Min, when global_context reports
        constructed mode, or loaded from its .npy file.

        Args:
            get_origin: boolean; when True tensors are returned in their original framework.
            tensor_platform: Const.MS_FRAMEWORK or Const.PT_FRAMEWORK; the framework
                tensors are converted to when get_origin is False.

        Return:
            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
        '''
        if isinstance(self.parameter, self.supported_parameter_type):
            parameter_tmp = self.parameter
        elif isinstance(self.parameter, MstensorMetaData):
            mstensor_meta_data = self.parameter
            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
            if global_context.get_is_constructed():
                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
                                                  mstensor_meta_data.minimum, np_dtype)
            else:
                ndarray = load_npy(mstensor_meta_data.npy_path)
            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
        else:
            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

        # if necessary, do transfer
        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            parameter = self.transfer_to_torch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
        else:
            parameter = parameter_tmp

        return parameter

    def get_shape(self):
        """Return the recorded shape (an empty tuple for non-tensor parameters)."""
        return self.shape

    def get_dtype(self):
        """Return the recorded dtype / type name string."""
        return self.dtype_str

    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
        """Build a deterministic random ndarray of ``shape`` within [minimum, maximum).

        Bool dtypes get a random boolean mask. A local RandomState seeded with 42
        yields the same values as reseeding the global RNG, without clobbering
        the caller's global numpy random state.
        """
        shape = tuple(shape)
        rng = np.random.RandomState(42)
        if np_dtype == np.bool_:
            ndarray = rng.rand(*shape) > 0.5
        else:
            # NOTE(review): convert_inf_to_real_num looks up a dtype *string*; passing np_dtype
            # means it always falls back to DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE for the inf bounds.
            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
            ndarray = rng.uniform(minimum, maximum, shape).astype(np_dtype)
        return ndarray

    def _init_from_compute_element_info(self, compute_element_info):
        '''
        Args:
            compute_element_info: Union[list, dict]
                A list is treated as a tuple of nested elements. A dict must
                carry a "type" field (plus "value" for non-tensor types).

        Return:
            void

        init member attributes: self.shape, self.dtype_str, self.parameter
        '''
        if isinstance(compute_element_info, list):
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            # Nested elements are materialised immediately via get_parameter().
            self.parameter = tuple(ComputeElement(compute_element_info=sub_info).get_parameter()
                                   for sub_info in compute_element_info)
        else:
            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())

            if type_str == MINDSPORE_TENSOR_TYPE_STR:
                self._init_from_mstensor_compute_element_info(compute_element_info)
            else:  # type_str in ("slice", "int", "float", "bool")
                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
                self.shape = tuple()
                self.dtype_str = type_str
                self.parameter = slice(*tuple(value)) if type_str == "slice" else value

    def _init_from_mstensor_compute_element_info(self, compute_element_info):
        '''
        do not load real tensor, only record meta data

        Reads "dtype" and "shape". In constructed mode it also reads "Max"/"Min";
        otherwise it reads "data_name", which is resolved against the dump data directory.
        '''
        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                             accepted_type=(list,))
        if global_context.get_is_constructed():
            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                                   accepted_type=(int, float))
            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                                   accepted_type=(int, float))

            npy_path = None
        else:
            maximum, minimum = None, None
            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                     "data_name field in api_info.json", accepted_type=(str,))
            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
        self.parameter = mstensor_meta_data
        self.dtype_str = dtype_str
        self.shape = tuple(shape)

    def _init_with_parameter(self, parameter):
        """Store a live parameter and derive its shape and dtype string."""
        self.parameter = parameter
        if not isinstance(parameter, self.supported_parameter_type):
            err_msg = "ComputeElement._init_with_parameter failed: " \
                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        if isinstance(parameter, mindspore.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, torch.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
        else:
            self.shape = tuple()
            self.dtype_str = \
                TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
|
|
2
|
+
|
|
3
|
+
def add_api_accuracy_checker_argument(parser):
    """Register the API accuracy checker command-line options on ``parser``.

    Options:
        -api_info / --api_info_file (required): path of the api info json file.
        -o / --out_path (optional, default "./"): output location for the result files.
    """
    api_info_help = ("<Required> The api param tool result file: generate from api param tool, "
                     "a json file.")
    out_path_help = "<optional> The ut task result out path."

    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str,
                        required=True, help=api_info_help)
    parser.add_argument("-o", "--out_path", dest="out_path", type=str,
                        default="./", required=False, help=out_path_help)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def api_checker_main(args):
    """Run the API accuracy checker end to end.

    Parses ``args.api_info_file``, runs and compares every API, then writes
    the detail and result csv files under ``args.out_path``.
    """
    checker = ApiAccuracyChecker()
    checker.parse(args.api_info_file)
    checker.run_and_compare()

    out_path = args.out_path
    checker.to_detail_csv(out_path)
    checker.to_result_csv(out_path)
|