mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
1
+ # MSAdapter 场景的溢出检测
2
+
3
+ msprobe 工具提供 MSAdapter 场景下的溢出检测功能。其检测对象为 **API** 级别(除 Primitive 和 Jit 类 API)或**模块**级别,分别对应 config.json 配置中的 **"L1"** 、**"L0"** level。
4
+
5
+ 需要注意,本工具仅支持在 INF/NAN 模式<sup>a</sup>下进行溢出检测。INF/NAN 模式的使能方式如下:
6
+
7
+ ```Shell
8
+ # 使能 CANN 侧 INF/NAN 模式
9
+ export INF_NAN_MODE_ENABLE=1
10
+ # 使能 MindSpore 框架侧 INF/NAN 模式
11
+ export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"
12
+ ```
13
+
14
+ **a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧与 MindSpore 框架侧配置须一致。
15
+
16
+ 溢出检测任务的配置示例见["**MindSpore 动态图场景 task 配置为 overflow_check**"](./03.config_examples.md#33-task配置为overflow_check)小节。
17
+
18
+
19
+ ## 1 接口介绍
20
+
21
+ 溢出检测功能提供的接口与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**2 接口介绍**"](./29.data_dump_MSAdapter.md#2-接口介绍)小节。
22
+
23
+ 需要注意,目前暂不支持 "L1" level 下 primitive op 的溢出检测。
24
+
25
+ ## 2 示例代码
26
+
27
+ 溢出检测功能使用方式与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**3 示例代码**"](./29.data_dump_MSAdapter.md#3-示例代码)小节。
28
+
29
+ ## 3 溢出检测结果文件介绍
30
+
31
+ 溢出检测结果文件的目录结构及含义与数据采集任务一致,但仅保存溢出 API 或模块的真实数据或统计信息。详见 MSAdapter 场景的精度数据采集中的["**4 dump 结果文件介绍**"](./29.data_dump_MSAdapter.md#4-dump-结果文件介绍)小节。
msprobe/docs/FAQ.md CHANGED
@@ -58,11 +58,7 @@
58
58
 
59
59
  答:对于 fp16 的数据,CPU 会上升一个精度 fp32 去计算,这是和算子那边对齐的精度结论,CPU 用更高精度去计算会更接近真实值。
60
60
 
61
- 6. 添加预检工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。
62
-
63
- 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 Tensor: 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要 dump 关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
64
-
65
- 7. Tensor 魔法函数具体对应什么操作?
61
+ 6. Tensor 魔法函数具体对应什么操作?
66
62
 
67
63
  答:
68
64
 
@@ -202,15 +198,11 @@ def npu_forward_fused_softmax(self, input_, mask):
202
198
 
203
199
  答:正常现象,dataloader 通过 raise 结束程序,堆栈信息可忽略。
204
200
 
205
- 10. 添加 msprobe 工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。
206
-
207
- 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
208
-
209
- 11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
201
+ 10. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
210
202
 
211
203
  答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法对这类 API 再进行封装,而加速库中会对某些 API 进行类型检查,即会把工具无法封装的原始的 API和工具封装之后的 API 进行判断,所以会报错。
212
204
  规避方式有 3 种:①将 PrecisionDebugger 的实例化放在文件的开始位置,即导包后的位置,确保所有 API 都被封装;②注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `- gelu` 或者 `- silu`,工具会跳过采集该 API;③根据报错堆栈信息注释引发报错的类型检查。
213
205
 
214
- 12. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
206
+ 11. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
215
207
 
216
208
  答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- t` 和 `- transpose`。
Binary file
Binary file
@@ -17,12 +17,13 @@ import os
17
17
 
18
18
  try:
19
19
  from msprobe.lib import _msprobe_c
20
- os.environ["MS_HOOK_ENABLE"] = "on"
21
20
  os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
22
21
  except ImportError:
23
22
  from .common.log import logger
24
23
  logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
25
24
 
26
25
  from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
27
- from msprobe.mindspore.common.utils import seed_all
28
- from msprobe.mindspore.monitor.module_hook import TrainerMon
26
+ from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep
27
+ from msprobe.mindspore.monitor.module_hook import TrainerMon
28
+
29
+ os.environ["MS_HOOK_ENABLE"] = "on"
@@ -16,7 +16,7 @@
16
16
  import os
17
17
  from tqdm import tqdm
18
18
 
19
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
19
+ from msprobe.core.common.const import Const, CompareConst
20
20
  from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
21
21
  from msprobe.core.common.utils import add_time_as_suffix
22
22
  from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
@@ -25,6 +25,7 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar
25
25
  from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
26
26
  from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
27
27
  trim_output_compute_element_list)
28
+ from msprobe.mindspore.common.const import MsCompareConst
28
29
  from msprobe.mindspore.common.log import logger
29
30
  from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
30
31
 
@@ -156,6 +157,7 @@ class ApiAccuracyChecker:
156
157
  real_api_str = Const.SEP.join(api_name_str_list[1:-2])
157
158
  api_list = load_yaml(yaml_path)
158
159
  supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
160
+ supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
159
161
  if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
160
162
  and global_context.get_framework() == Const.MS_FRAMEWORK:
161
163
  return True
@@ -165,6 +167,9 @@ class ApiAccuracyChecker:
165
167
  if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
166
168
  and global_context.get_framework() == Const.MS_FRAMEWORK:
167
169
  return True
170
+ if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
171
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
172
+ return True
168
173
  return False
169
174
 
170
175
  def parse(self, api_info_path):
@@ -15,11 +15,13 @@
15
15
 
16
16
  import mindspore
17
17
  from mindspore import ops
18
- from msprobe.core.common.const import Const, MsCompareConst
18
+ from msprobe.core.common.const import Const
19
19
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
20
20
  from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
21
21
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
22
22
  from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
23
+ from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion
24
+ from msprobe.mindspore.common.const import MsCompareConst
23
25
  from msprobe.mindspore.common.log import logger
24
26
 
25
27
 
@@ -64,7 +66,9 @@ api_parent_module_mapping = {
64
66
  (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
65
67
  (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
66
68
  (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
67
- (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed
69
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
70
+ (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops,
71
+ (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion
68
72
 
69
73
  }
70
74
 
@@ -83,7 +87,9 @@ api_parent_module_str_mapping = {
83
87
  (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
84
88
  (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
85
89
  (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
86
- (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed"
90
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
91
+ (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops",
92
+ (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion"
87
93
  }
88
94
 
89
95
 
@@ -125,7 +131,8 @@ class ApiRunner:
125
131
  err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
126
132
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
127
133
  api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
128
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \
134
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
135
+ MsCompareConst.FUNCTIONAL_API] \
129
136
  and api_platform == Const.MS_FRAMEWORK:
130
137
  err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
131
138
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
@@ -139,9 +146,9 @@ class ApiRunner:
139
146
  def get_api_instance(api_type_str, api_sub_name, api_platform):
140
147
  """
141
148
  Args:
142
- api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
149
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
143
150
  api_sub_name: str, e.g. "relu"
144
- api_platform: str: Union["mindpore", "torch"]
151
+ api_platform: str: Union["mindspore", "pytorch"]
145
152
 
146
153
  Return:
147
154
  api_instance: function object
@@ -151,9 +158,12 @@ class ApiRunner:
151
158
  mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
152
159
  mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
153
160
  """
154
-
155
- api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
156
- api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
161
+ if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
162
+ api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
163
+ api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
164
+ else:
165
+ api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
166
+ api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
157
167
  full_api_name = api_parent_module_str + Const.SEP + api_sub_name
158
168
 
159
169
  if not hasattr(api_parent_module, api_sub_name):
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
18
18
  import mindspore
19
19
  import numpy as np
20
20
  import torch
21
- from msprobe.core.common.const import CompareConst, MsCompareConst
21
+ from msprobe.core.common.const import CompareConst
22
22
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
23
23
  from msprobe.mindspore.common.log import logger
24
+ from msprobe.mindspore.common.const import MsCompareConst
24
25
 
25
26
 
26
27
  class CompareResult: