mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# MSAdapter 场景的溢出检测
|
|
2
|
+
|
|
3
|
+
msprobe 工具提供 MSAdapter 场景下的溢出检测功能。其检测对象为 **API** 级别(除 Primitive 和 Jit 类 API)或**模块**级别,分别对应 config.json 配置中的 **"L1"** 、**"L0"** level。
|
|
4
|
+
|
|
5
|
+
需要注意,本工具仅支持在 INF/NAN 模式<sup>a</sup>下进行溢出检测。INF/NAN 模式的使能方式如下:
|
|
6
|
+
|
|
7
|
+
```Shell
|
|
8
|
+
# 使能 CANN 侧 INF/NAN 模式
|
|
9
|
+
export INF_NAN_MODE_ENABLE=1
|
|
10
|
+
# 使能 MindSpore 框架侧 INF/NAN 模式
|
|
11
|
+
export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧与 MindSpore 框架侧配置须一致。
|
|
15
|
+
|
|
16
|
+
溢出检测任务的配置示例见["**MindSpore 动态图场景 task 配置为 overflow_check**"](./03.config_examples.md#33-task配置为overflow_check)小节。
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
## 1 接口介绍
|
|
20
|
+
|
|
21
|
+
溢出检测功能提供的接口与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**2 接口介绍**"](./29.data_dump_MSAdapter.md#2-接口介绍)小节。
|
|
22
|
+
|
|
23
|
+
需要注意,目前不支持 "L1" level 下 Primitive op 的溢出检测。
|
|
24
|
+
|
|
25
|
+
## 2 示例代码
|
|
26
|
+
|
|
27
|
+
溢出检测功能使用方式与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**3 示例代码**"](./29.data_dump_MSAdapter.md#3-示例代码)小节。
|
|
28
|
+
|
|
29
|
+
## 3 溢出检测结果文件介绍
|
|
30
|
+
|
|
31
|
+
溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或模块的真实数据或统计信息。详见 MSAdapter 场景的精度数据采集中的["**4 dump 结果文件介绍**"](./29.data_dump_MSAdapter.md#4-dump-结果文件介绍)小节。
|
msprobe/docs/FAQ.md
CHANGED
|
@@ -58,11 +58,7 @@
|
|
|
58
58
|
|
|
59
59
|
答:对于 fp16 的数据,CPU 会上升一个精度 fp32 去计算,这是和算子那边对齐的精度结论,CPU 用更高精度去计算会更接近真实值。
|
|
60
60
|
|
|
61
|
-
6.
|
|
62
|
-
|
|
63
|
-
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 Tensor: 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要 dump 关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
64
|
-
|
|
65
|
-
7. Tensor 魔法函数具体对应什么操作?
|
|
61
|
+
6. Tensor 魔法函数具体对应什么操作?
|
|
66
62
|
|
|
67
63
|
答:
|
|
68
64
|
|
|
@@ -202,15 +198,11 @@ def npu_forward_fused_softmax(self, input_, mask):
|
|
|
202
198
|
|
|
203
199
|
答:正常现象,dataloader 通过 raise 结束程序,堆栈信息可忽略。
|
|
204
200
|
|
|
205
|
-
10.
|
|
206
|
-
|
|
207
|
-
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
208
|
-
|
|
209
|
-
11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
|
|
201
|
+
10. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
|
|
210
202
|
|
|
211
203
|
答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API 类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法对这类 API 再进行封装,而加速库中会对某些 API 进行类型检查,即会把工具无法封装的原始的 API 和工具封装之后的 API 进行判断,所以会报错。
|
|
212
204
|
规避方式有 3 种:① 将 PrecisionDebugger 的实例化放在文件的开始位置,即导包后的位置,确保所有 API 都被封装;② 注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `- gelu` 或者 `- silu`,工具会跳过采集该 API;③ 可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
213
205
|
|
|
214
|
-
|
|
206
|
+
11. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
|
|
215
207
|
|
|
216
208
|
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- t` 和 `- transpose`。
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -17,12 +17,13 @@ import os
|
|
|
17
17
|
|
|
18
18
|
try:
|
|
19
19
|
from msprobe.lib import _msprobe_c
|
|
20
|
-
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
21
20
|
os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
|
|
22
21
|
except ImportError:
|
|
23
22
|
from .common.log import logger
|
|
24
23
|
logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
|
|
25
24
|
|
|
26
25
|
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
27
|
-
from msprobe.mindspore.common.utils import seed_all
|
|
28
|
-
from msprobe.mindspore.monitor.module_hook import TrainerMon
|
|
26
|
+
from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep
|
|
27
|
+
from msprobe.mindspore.monitor.module_hook import TrainerMon
|
|
28
|
+
|
|
29
|
+
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import os
|
|
17
17
|
from tqdm import tqdm
|
|
18
18
|
|
|
19
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
20
20
|
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
|
|
21
21
|
from msprobe.core.common.utils import add_time_as_suffix
|
|
22
22
|
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
@@ -25,6 +25,7 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar
|
|
|
25
25
|
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
26
26
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
27
27
|
trim_output_compute_element_list)
|
|
28
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
28
29
|
from msprobe.mindspore.common.log import logger
|
|
29
30
|
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
30
31
|
|
|
@@ -156,6 +157,7 @@ class ApiAccuracyChecker:
|
|
|
156
157
|
real_api_str = Const.SEP.join(api_name_str_list[1:-2])
|
|
157
158
|
api_list = load_yaml(yaml_path)
|
|
158
159
|
supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
|
|
160
|
+
supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
|
|
159
161
|
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
|
|
160
162
|
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
161
163
|
return True
|
|
@@ -165,6 +167,9 @@ class ApiAccuracyChecker:
|
|
|
165
167
|
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
|
|
166
168
|
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
167
169
|
return True
|
|
170
|
+
if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
|
|
171
|
+
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
172
|
+
return True
|
|
168
173
|
return False
|
|
169
174
|
|
|
170
175
|
def parse(self, api_info_path):
|
|
@@ -15,11 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
import mindspore
|
|
17
17
|
from mindspore import ops
|
|
18
|
-
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
19
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
20
20
|
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
21
21
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
22
22
|
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
23
|
+
from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion
|
|
24
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
23
25
|
from msprobe.mindspore.common.log import logger
|
|
24
26
|
|
|
25
27
|
|
|
@@ -64,7 +66,9 @@ api_parent_module_mapping = {
|
|
|
64
66
|
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
|
|
65
67
|
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
66
68
|
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
|
|
67
|
-
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed
|
|
69
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
|
|
70
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops,
|
|
71
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion
|
|
68
72
|
|
|
69
73
|
}
|
|
70
74
|
|
|
@@ -83,7 +87,9 @@ api_parent_module_str_mapping = {
|
|
|
83
87
|
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
|
|
84
88
|
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
85
89
|
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
|
|
86
|
-
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed"
|
|
90
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
|
|
91
|
+
(MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops",
|
|
92
|
+
(MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion"
|
|
87
93
|
}
|
|
88
94
|
|
|
89
95
|
|
|
@@ -125,7 +131,8 @@ class ApiRunner:
|
|
|
125
131
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
126
132
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
127
133
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
128
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API
|
|
134
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
|
|
135
|
+
MsCompareConst.FUNCTIONAL_API] \
|
|
129
136
|
and api_platform == Const.MS_FRAMEWORK:
|
|
130
137
|
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
131
138
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
@@ -139,9 +146,9 @@ class ApiRunner:
|
|
|
139
146
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
140
147
|
"""
|
|
141
148
|
Args:
|
|
142
|
-
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
149
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
|
|
143
150
|
api_sub_name: str, e.g. "relu"
|
|
144
|
-
api_platform: str: Union["mindpore", "
|
|
151
|
+
api_platform: str: Union["mindpore", "pytorch"]
|
|
145
152
|
|
|
146
153
|
Return:
|
|
147
154
|
api_instance: function object
|
|
@@ -151,9 +158,12 @@ class ApiRunner:
|
|
|
151
158
|
mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
|
|
152
159
|
mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
|
|
153
160
|
"""
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
161
|
+
if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
|
|
162
|
+
api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
163
|
+
api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
|
|
164
|
+
else:
|
|
165
|
+
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
166
|
+
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
157
167
|
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
158
168
|
|
|
159
169
|
if not hasattr(api_parent_module, api_sub_name):
|
|
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
|
|
|
18
18
|
import mindspore
|
|
19
19
|
import numpy as np
|
|
20
20
|
import torch
|
|
21
|
-
from msprobe.core.common.const import CompareConst
|
|
21
|
+
from msprobe.core.common.const import CompareConst
|
|
22
22
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
23
23
|
from msprobe.mindspore.common.log import logger
|
|
24
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class CompareResult:
|