mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -14,24 +14,39 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import mindspore
17
- import torch
18
17
  from mindspore import ops
19
- from msprobe.core.common.const import Const, MsCompareConst
18
+ from msprobe.core.common.const import Const
20
19
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
21
20
  from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
22
21
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
23
22
  from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
23
+ from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion
24
+ from msprobe.mindspore.common.const import MsCompareConst
24
25
  from msprobe.mindspore.common.log import logger
25
26
 
26
27
 
28
+ from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
29
+
30
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
31
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_tensor
32
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_func
33
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_dist
34
+
35
+ if torch_mindtorch_importer.is_valid_pt_mt_env:
36
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
37
+ else:
38
+ import torch
39
+
40
+
41
+
27
42
  class ApiInputAggregation:
28
43
  def __init__(self, inputs, kwargs, gradient_inputs) -> None:
29
- '''
44
+ """
30
45
  Args:
31
46
  inputs: List[ComputeElement]
32
47
  kwargs: dict{str: ComputeElement}
33
48
  gradient_inputs: Union[List[ComputeElement], None]
34
- '''
49
+ """
35
50
  self.inputs = inputs
36
51
  self.kwargs = kwargs
37
52
  self.gradient_inputs = gradient_inputs
@@ -43,16 +58,38 @@ api_parent_module_mapping = {
43
58
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
44
59
  (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
45
60
  (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
46
- (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
61
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
62
+ (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
63
+ (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
64
+ (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
65
+ (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
66
+ (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
67
+ (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
68
+ (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
69
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
70
+ (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops,
71
+ (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion
72
+
47
73
  }
48
74
 
75
+
49
76
  api_parent_module_str_mapping = {
50
77
  (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
51
78
  (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
52
79
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
53
80
  (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
54
81
  (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
55
- (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
82
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
83
+ (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
84
+ (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
85
+ (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
86
+ (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
87
+ (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
88
+ (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
89
+ (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
90
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
91
+ (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops",
92
+ (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion"
56
93
  }
57
94
 
58
95
 
@@ -64,7 +101,7 @@ class ApiRunner:
64
101
  api_input_aggregation: ApiInputAggregation
65
102
  api_name_str: str, e.g. "MintFunctional.relu.0"
66
103
  forward_or_backward: str, Union["forward", "backward"]
67
- api_platform: str, Union["mindspore", "torch"]
104
+ api_platform: str, Union["mindspore", "torch", "mindtorch"]
68
105
 
69
106
  Return:
70
107
  outputs: list[ComputeElement]
@@ -72,39 +109,46 @@ class ApiRunner:
72
109
  Description:
73
110
  run mindspore.mint/torch api
74
111
  '''
75
- api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
112
+
113
+ api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
76
114
  api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
77
115
 
78
116
  return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
79
117
 
80
118
  @staticmethod
81
- def get_info_from_name(api_name_str):
82
- '''
119
+ def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
120
+ """
83
121
  Args:
84
122
  api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
85
-
123
+ api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
124
+ It specifies which framework is being used. Default is "mindspore".
86
125
  Return:
87
- api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
126
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
88
127
  api_sub_name: str, e.g. "relu"
89
- '''
128
+ """
90
129
  api_name_list = api_name_str.split(Const.SEP)
91
130
  if len(api_name_list) != 3:
92
131
  err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
93
132
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
94
133
  api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
95
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
134
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
135
+ MsCompareConst.FUNCTIONAL_API] \
136
+ and api_platform == Const.MS_FRAMEWORK:
96
137
  err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
97
138
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
98
139
 
140
+ if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
141
+ err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
142
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
99
143
  return api_type_str, api_sub_name
100
144
 
101
145
  @staticmethod
102
146
  def get_api_instance(api_type_str, api_sub_name, api_platform):
103
- '''
147
+ """
104
148
  Args:
105
- api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
149
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
106
150
  api_sub_name: str, e.g. "relu"
107
- api_platform: str: Union["mindpore", "torch"]
151
+ api_platform: str: Union["mindpore", "pytorch"]
108
152
 
109
153
  Return:
110
154
  api_instance: function object
@@ -113,11 +157,15 @@ class ApiRunner:
113
157
  get mindspore.mint/torch api fucntion
114
158
  mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
115
159
  mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
116
- '''
117
-
118
- api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
119
- api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
160
+ """
161
+ if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
162
+ api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
163
+ api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
164
+ else:
165
+ api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
166
+ api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
120
167
  full_api_name = api_parent_module_str + Const.SEP + api_sub_name
168
+
121
169
  if not hasattr(api_parent_module, api_sub_name):
122
170
  err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
123
171
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
@@ -147,7 +195,7 @@ class ApiRunner:
147
195
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
148
196
  gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
149
197
  for compute_element in gradient_inputs)
150
- if api_platform == Const.MS_FRAMEWORK:
198
+ if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
151
199
  if len(gradient_inputs) == 1:
152
200
  gradient_inputs = gradient_inputs[0]
153
201
 
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
18
18
  import mindspore
19
19
  import numpy as np
20
20
  import torch
21
- from msprobe.core.common.const import CompareConst, MsCompareConst
21
+ from msprobe.core.common.const import CompareConst
22
22
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
23
23
  from msprobe.mindspore.common.log import logger
24
+ from msprobe.mindspore.common.const import MsCompareConst
24
25
 
25
26
 
26
27
  class CompareResult: