mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,13 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import os
|
|
17
|
+
import numpy as np
|
|
16
18
|
import mindspore
|
|
17
19
|
from mindspore import ops
|
|
18
|
-
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
19
21
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
20
22
|
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
21
23
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
22
24
|
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion
|
|
26
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
23
27
|
from msprobe.mindspore.common.log import logger
|
|
24
28
|
|
|
25
29
|
|
|
@@ -36,7 +40,6 @@ else:
|
|
|
36
40
|
import torch
|
|
37
41
|
|
|
38
42
|
|
|
39
|
-
|
|
40
43
|
class ApiInputAggregation:
|
|
41
44
|
def __init__(self, inputs, kwargs, gradient_inputs) -> None:
|
|
42
45
|
"""
|
|
@@ -64,7 +67,9 @@ api_parent_module_mapping = {
|
|
|
64
67
|
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
|
|
65
68
|
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
66
69
|
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
|
|
67
|
-
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed
|
|
70
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
|
|
71
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops,
|
|
72
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion
|
|
68
73
|
|
|
69
74
|
}
|
|
70
75
|
|
|
@@ -83,7 +88,9 @@ api_parent_module_str_mapping = {
|
|
|
83
88
|
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
|
|
84
89
|
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
85
90
|
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
|
|
86
|
-
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed"
|
|
91
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
|
|
92
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops",
|
|
93
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion"
|
|
87
94
|
}
|
|
88
95
|
|
|
89
96
|
|
|
@@ -125,7 +132,8 @@ class ApiRunner:
|
|
|
125
132
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
126
133
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
127
134
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
128
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API
|
|
135
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
|
|
136
|
+
MsCompareConst.FUNCTIONAL_API] \
|
|
129
137
|
and api_platform == Const.MS_FRAMEWORK:
|
|
130
138
|
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
131
139
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
@@ -139,21 +147,24 @@ class ApiRunner:
|
|
|
139
147
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
140
148
|
"""
|
|
141
149
|
Args:
|
|
142
|
-
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
150
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
|
|
143
151
|
api_sub_name: str, e.g. "relu"
|
|
144
|
-
api_platform: str: Union["
|
|
152
|
+
api_platform: str: Union["mindspore", "pytorch"]
|
|
145
153
|
|
|
146
154
|
Return:
|
|
147
155
|
api_instance: function object
|
|
148
156
|
|
|
149
157
|
Description:
|
|
150
|
-
get mindspore.mint/torch api
|
|
158
|
+
get mindspore.mint/torch api function
|
|
151
159
|
mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
|
|
152
160
|
mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
|
|
153
161
|
"""
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
162
|
+
if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
|
|
163
|
+
api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
164
|
+
api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
165
|
+
else:
|
|
166
|
+
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
167
|
+
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
157
168
|
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
158
169
|
|
|
159
170
|
if not hasattr(api_parent_module, api_sub_name):
|
|
@@ -179,6 +190,8 @@ class ApiRunner:
|
|
|
179
190
|
forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
|
|
180
191
|
forward_result_tuple = convert_to_tuple(forward_result)
|
|
181
192
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
|
|
193
|
+
if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
|
|
194
|
+
return res_compute_element_list, inputs, kwargs, forward_result_tuple
|
|
182
195
|
else:
|
|
183
196
|
if gradient_inputs is None:
|
|
184
197
|
err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
|
|
@@ -196,6 +209,7 @@ class ApiRunner:
|
|
|
196
209
|
backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
|
|
197
210
|
backward_result_tuple = convert_to_tuple(backward_result)
|
|
198
211
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
|
|
212
|
+
return res_compute_element_list, gradient_inputs, backward_result_tuple
|
|
199
213
|
else:
|
|
200
214
|
# set requires_grad
|
|
201
215
|
requires_grad_index = []
|
|
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
|
|
|
18
18
|
import mindspore
|
|
19
19
|
import numpy as np
|
|
20
20
|
import torch
|
|
21
|
-
from msprobe.core.common.const import CompareConst
|
|
21
|
+
from msprobe.core.common.const import CompareConst
|
|
22
22
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
23
23
|
from msprobe.mindspore.common.log import logger
|
|
24
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class CompareResult:
|