mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -14,24 +14,39 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import mindspore
|
|
17
|
-
import torch
|
|
18
17
|
from mindspore import ops
|
|
19
|
-
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
20
19
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
21
20
|
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
22
21
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
23
22
|
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
23
|
+
from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion
|
|
24
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
24
25
|
from msprobe.mindspore.common.log import logger
|
|
25
26
|
|
|
26
27
|
|
|
28
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
29
|
+
|
|
30
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
31
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_tensor
|
|
32
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_func
|
|
33
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_dist
|
|
34
|
+
|
|
35
|
+
if torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
36
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
37
|
+
else:
|
|
38
|
+
import torch
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
27
42
|
class ApiInputAggregation:
|
|
28
43
|
def __init__(self, inputs, kwargs, gradient_inputs) -> None:
|
|
29
|
-
|
|
44
|
+
"""
|
|
30
45
|
Args:
|
|
31
46
|
inputs: List[ComputeElement]
|
|
32
47
|
kwargs: dict{str: ComputeElement}
|
|
33
48
|
gradient_inputs: Union[List[ComputeElement], None]
|
|
34
|
-
|
|
49
|
+
"""
|
|
35
50
|
self.inputs = inputs
|
|
36
51
|
self.kwargs = kwargs
|
|
37
52
|
self.gradient_inputs = gradient_inputs
|
|
@@ -43,16 +58,38 @@ api_parent_module_mapping = {
|
|
|
43
58
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
|
|
44
59
|
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
45
60
|
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
|
|
46
|
-
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
|
|
61
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
|
|
62
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
|
|
63
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
|
|
64
|
+
(MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
|
|
65
|
+
(MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
|
|
66
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
|
|
67
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
68
|
+
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
|
|
69
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
|
|
70
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops,
|
|
71
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion
|
|
72
|
+
|
|
47
73
|
}
|
|
48
74
|
|
|
75
|
+
|
|
49
76
|
api_parent_module_str_mapping = {
|
|
50
77
|
(MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
|
|
51
78
|
(MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
|
|
52
79
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
|
|
53
80
|
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
54
81
|
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
|
|
55
|
-
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
|
|
82
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
|
|
83
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
|
|
84
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
|
|
85
|
+
(MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
|
|
86
|
+
(MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
|
|
87
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
|
|
88
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
89
|
+
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
|
|
90
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
|
|
91
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops",
|
|
92
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion"
|
|
56
93
|
}
|
|
57
94
|
|
|
58
95
|
|
|
@@ -64,7 +101,7 @@ class ApiRunner:
|
|
|
64
101
|
api_input_aggregation: ApiInputAggregation
|
|
65
102
|
api_name_str: str, e.g. "MintFunctional.relu.0"
|
|
66
103
|
forward_or_backward: str, Union["forward", "backward"]
|
|
67
|
-
api_platform: str, Union["mindspore", "torch"]
|
|
104
|
+
api_platform: str, Union["mindspore", "torch", "mindtorch"]
|
|
68
105
|
|
|
69
106
|
Return:
|
|
70
107
|
outputs: list[ComputeElement]
|
|
@@ -72,39 +109,46 @@ class ApiRunner:
|
|
|
72
109
|
Description:
|
|
73
110
|
run mindspore.mint/torch api
|
|
74
111
|
'''
|
|
75
|
-
|
|
112
|
+
|
|
113
|
+
api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
|
|
76
114
|
api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
|
|
77
115
|
|
|
78
116
|
return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
|
|
79
117
|
|
|
80
118
|
@staticmethod
|
|
81
|
-
def get_info_from_name(api_name_str):
|
|
82
|
-
|
|
119
|
+
def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
|
|
120
|
+
"""
|
|
83
121
|
Args:
|
|
84
122
|
api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
|
|
85
|
-
|
|
123
|
+
api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
|
|
124
|
+
It specifies which framework is being used. Default is "mindspore".
|
|
86
125
|
Return:
|
|
87
|
-
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
126
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
|
|
88
127
|
api_sub_name: str, e.g. "relu"
|
|
89
|
-
|
|
128
|
+
"""
|
|
90
129
|
api_name_list = api_name_str.split(Const.SEP)
|
|
91
130
|
if len(api_name_list) != 3:
|
|
92
131
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
93
132
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
94
133
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
95
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API
|
|
134
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
|
|
135
|
+
MsCompareConst.FUNCTIONAL_API] \
|
|
136
|
+
and api_platform == Const.MS_FRAMEWORK:
|
|
96
137
|
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
97
138
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
98
139
|
|
|
140
|
+
if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
|
|
141
|
+
err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
|
|
142
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
99
143
|
return api_type_str, api_sub_name
|
|
100
144
|
|
|
101
145
|
@staticmethod
|
|
102
146
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
103
|
-
|
|
147
|
+
"""
|
|
104
148
|
Args:
|
|
105
|
-
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
149
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
|
|
106
150
|
api_sub_name: str, e.g. "relu"
|
|
107
|
-
api_platform: str: Union["mindpore", "
|
|
151
|
+
api_platform: str: Union["mindpore", "pytorch"]
|
|
108
152
|
|
|
109
153
|
Return:
|
|
110
154
|
api_instance: function object
|
|
@@ -113,11 +157,15 @@ class ApiRunner:
|
|
|
113
157
|
get mindspore.mint/torch api fucntion
|
|
114
158
|
mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
|
|
115
159
|
mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
160
|
+
"""
|
|
161
|
+
if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
|
|
162
|
+
api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
163
|
+
api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
164
|
+
else:
|
|
165
|
+
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
166
|
+
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
120
167
|
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
168
|
+
|
|
121
169
|
if not hasattr(api_parent_module, api_sub_name):
|
|
122
170
|
err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
|
|
123
171
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
|
|
@@ -147,7 +195,7 @@ class ApiRunner:
|
|
|
147
195
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
148
196
|
gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
|
|
149
197
|
for compute_element in gradient_inputs)
|
|
150
|
-
if api_platform == Const.MS_FRAMEWORK:
|
|
198
|
+
if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
|
|
151
199
|
if len(gradient_inputs) == 1:
|
|
152
200
|
gradient_inputs = gradient_inputs[0]
|
|
153
201
|
|
|
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
|
|
|
18
18
|
import mindspore
|
|
19
19
|
import numpy as np
|
|
20
20
|
import torch
|
|
21
|
-
from msprobe.core.common.const import CompareConst
|
|
21
|
+
from msprobe.core.common.const import CompareConst
|
|
22
22
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
23
23
|
from msprobe.mindspore.common.log import logger
|
|
24
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class CompareResult:
|