mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,16 +1,34 @@
1
- import json
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import os
17
+ from tqdm import tqdm
3
18
 
4
- from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
- from msprobe.core.common.utils import add_time_as_suffix
6
19
  from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
- from msprobe.mindspore.common.log import logger
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
21
+ from msprobe.core.common.utils import add_time_as_suffix
8
22
  from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
23
  from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
24
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
25
+ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
11
26
  from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
27
  trim_output_compute_element_list)
28
+ from msprobe.mindspore.common.log import logger
13
29
 
30
+ cur_path = os.path.dirname(os.path.realpath(__file__))
31
+ yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
14
32
 
15
33
  class BasicInfoAndStatus:
16
34
  def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
@@ -21,6 +39,7 @@ class BasicInfoAndStatus:
21
39
  self.status = status
22
40
  self.err_msg = err_msg
23
41
 
42
+
24
43
  class ResultCsvEntry:
25
44
  def __init__(self) -> None:
26
45
  self.forward_pass_status = None
@@ -31,9 +50,9 @@ class ResultCsvEntry:
31
50
 
32
51
 
33
52
  class ApiAccuracyChecker:
34
- def __init__(self):
53
+ def __init__(self, args):
35
54
  self.api_infos = dict()
36
- self.results = dict()
55
+ self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager
37
56
 
38
57
  @staticmethod
39
58
  def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
@@ -80,13 +99,13 @@ class ApiAccuracyChecker:
80
99
  compare_result_dict[compare_algorithm_name] = compare_result
81
100
 
82
101
  if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
83
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
102
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
84
103
  status = CompareConst.PASS
85
104
  err_msg = ""
86
105
  else:
87
106
  status = CompareConst.ERROR
88
107
  err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
89
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
108
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
90
109
  basic_info_status = \
91
110
  BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
92
111
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
@@ -109,13 +128,35 @@ class ApiAccuracyChecker:
109
128
  gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
110
129
  return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
111
130
 
131
+ @staticmethod
132
+ def is_api_checkable(api_name_str):
133
+ '''
134
+ Args:
135
+ api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
136
+ Returns:
137
+ is_checkable: bool
138
+ Description:
139
+ tell whether this api is checkable based on the key in "data" dict in api_info.json
140
+ '''
141
+ api_name_str_list = api_name_str.split(Const.SEP)
142
+ if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
143
+ return False
144
+ api_type_str = api_name_str_list[0]
145
+ real_api_str = Const.SEP.join(api_name_str_list[1:-2])
146
+ api_list = load_yaml(yaml_path)
147
+ supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
148
+ if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
149
+ return True
150
+ if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list:
151
+ return True
152
+ return False
153
+
112
154
  def parse(self, api_info_path):
113
- with FileOpen(api_info_path, "r") as f:
114
- api_info_dict = json.load(f)
155
+ api_info_dict = load_json(api_info_path)
115
156
 
116
157
  # init global context
117
158
  task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
118
- "task field in api_info.json",accepted_type=str,
159
+ "task field in api_info.json", accepted_type=str,
119
160
  accepted_value=(MsCompareConst.STATISTICS_TASK,
120
161
  MsCompareConst.TENSOR_TASK))
121
162
  is_constructed = task == MsCompareConst.STATISTICS_TASK
@@ -129,14 +170,12 @@ class ApiAccuracyChecker:
129
170
  api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
130
171
  "data field in api_info.json", accepted_type=dict)
131
172
  for api_name, api_info in api_info_data.items():
132
- is_mint = api_name.split(Const.SEP)[0] in \
133
- (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
134
- if not is_mint:
173
+ if not self.is_api_checkable(api_name):
135
174
  continue
136
175
  forbackward_str = api_name.split(Const.SEP)[-1]
137
176
  if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
138
177
  logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
139
- api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
178
+ api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
140
179
  if api_name not in self.api_infos:
141
180
  self.api_infos[api_name] = ApiInfo(api_name)
142
181
 
@@ -145,135 +184,64 @@ class ApiAccuracyChecker:
145
184
  else:
146
185
  self.api_infos[api_name].load_backward_info(api_info)
147
186
 
187
+ def process_forward(self, api_name_str, api_info):
188
+ """处理前向检查"""
189
+ if not api_info.check_forward_info():
190
+ logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
191
+ return Const.EXCEPTION_NONE
192
+
193
+ try:
194
+ forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
195
+ except Exception as e:
196
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
197
+ f"Skipping forward check. Detailed exception information: {e}.")
198
+ return Const.EXCEPTION_NONE
199
+
200
+ forward_output_list = None
201
+ try:
202
+ forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
203
+ except Exception as e:
204
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
205
+ f"Detailed exception information: {e}.")
206
+ return forward_output_list
207
+
208
+ def process_backward(self, api_name_str, api_info):
209
+ """处理反向检查"""
210
+ if not api_info.check_backward_info():
211
+ logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
212
+ return Const.EXCEPTION_NONE
213
+
214
+ try:
215
+ backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
216
+ except Exception as e:
217
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
218
+ f"Skipping backward check. Detailed exception information: {e}.")
219
+ return Const.EXCEPTION_NONE
220
+
221
+ backward_output_list = None
222
+ try:
223
+ backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
224
+ except Exception as e:
225
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
226
+ f"Detailed exception information: {e}.")
227
+ return backward_output_list
228
+
229
+
230
+
148
231
  def run_and_compare(self):
149
- for api_name_str, api_info in self.api_infos.items():
150
- if not api_info.check_forward_info():
151
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check.")
152
- continue
153
- try:
154
- forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
155
- except Exception as e:
156
- logger.warning(f"exception occurs when getting inputs for {api_name_str} forward api. "
157
- f"skip forward and backward check. detailed exception information: {e}.")
158
- continue
159
- forward_output_list = None
160
- try:
161
- forward_output_list = \
162
- self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
163
- except Exception as e:
164
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api. "
165
- f"detailed exception information: {e}.")
166
- self.record(forward_output_list)
167
-
168
- if not api_info.check_backward_info():
169
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check.")
170
- continue
171
- try:
172
- backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
173
- except Exception as e:
174
- logger.warning(f"exception occurs when getting inputs for {api_name_str} backward api. "
175
- f"skip backward check. detailed exception information: {e}.")
232
+ for api_name_str, api_info in tqdm(self.api_infos.items()):
233
+ if not self.data_manager.is_unique_api(api_name_str):
176
234
  continue
177
- backward_output_list = None
178
- try:
179
- backward_output_list = \
180
- self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
181
- except Exception as e:
182
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api. "
183
- f"detailed exception information: {e}.")
184
- self.record(backward_output_list)
185
-
186
- def record(self, output_list):
187
- if output_list is None:
188
- return
189
- for output in output_list:
190
- api_real_name, forward_or_backward, basic_info, compare_result_dict = output
191
- key = tuple([api_real_name, forward_or_backward])
192
- if key not in self.results:
193
- self.results[key] = []
194
- self.results[key].append(tuple([basic_info, compare_result_dict]))
195
-
196
-
197
- def to_detail_csv(self, csv_dir):
198
- # detail_csv
199
- detail_csv = []
200
- detail_csv_header_basic_info = [
201
- MsCompareConst.DETAIL_CSV_API_NAME,
202
- MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
203
- MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
204
- MsCompareConst.DETAIL_CSV_SHAPE,
205
- ]
206
- detail_csv_header_compare_result = list(compare_algorithms.keys())
207
- detail_csv_header_status = [
208
- MsCompareConst.DETAIL_CSV_PASS_STATUS,
209
- MsCompareConst.DETAIL_CSV_MESSAGE,
210
- ]
211
-
212
- detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
213
- detail_csv.append(detail_csv_header)
214
-
215
- for _, results in self.results.items():
216
- # detail csv
217
- for res in results:
218
- basic_info, compare_result_dict = res
219
- csv_row_basic_info = \
220
- [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
221
- csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
222
- for algorithm_name in detail_csv_header_compare_result)
223
- csv_row_status = [basic_info.status, basic_info.err_msg]
224
- csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
225
- detail_csv.append(csv_row)
226
-
227
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
228
- create_directory(csv_dir)
229
- write_csv(detail_csv, file_name, mode="w")
230
-
231
-
232
- def to_result_csv(self, csv_dir):
233
- result_csv_dict = dict()
234
- for key, results in self.results.items():
235
- api_real_name, forward_or_backward = key
236
- forward_or_backward_pass_status = CompareConst.PASS
237
- forward_or_backward_overall_err_msg = ""
238
- # detail csv
239
- for res in results:
240
- basic_info, _ = res
241
- if basic_info.status != CompareConst.PASS:
242
- forward_or_backward_pass_status = CompareConst.ERROR
243
- forward_or_backward_overall_err_msg += basic_info.err_msg
244
- forward_or_backward_overall_err_msg = \
245
- "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
246
-
247
- #result_csv_dict
248
- if api_real_name not in result_csv_dict:
249
- result_csv_dict[api_real_name] = ResultCsvEntry()
250
- if forward_or_backward == Const.FORWARD:
251
- result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
252
- result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
253
- else:
254
- result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
255
- result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
256
-
257
- #result_csv
258
- result_csv = []
259
- result_csv_header = [
260
- MsCompareConst.DETAIL_CSV_API_NAME,
261
- MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
262
- MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
263
- MsCompareConst.DETAIL_CSV_MESSAGE,
264
- ]
265
- result_csv.append(result_csv_header)
266
-
267
- for api_name, result_csv_entry in result_csv_dict.items():
268
- if result_csv_entry.forward_pass_status == CompareConst.PASS and \
269
- result_csv_entry.backward_pass_status == CompareConst.PASS:
270
- overall_err_msg = ""
271
- else:
272
- overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
273
- row = [api_name, result_csv_entry.forward_pass_status,
274
- result_csv_entry.backward_pass_status, overall_err_msg]
275
- result_csv.append(row)
276
-
277
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
278
- create_directory(csv_dir)
279
- write_csv(result_csv, file_name, mode="w")
235
+
236
+ # 处理前向
237
+ forward_output_list = self.process_forward(api_name_str, api_info)
238
+ if forward_output_list is not Const.EXCEPTION_NONE:
239
+ self.data_manager.record(forward_output_list)
240
+
241
+ # 处理反向
242
+ backward_output_list = self.process_backward(api_name_str, api_info)
243
+ if backward_output_list is not Const.EXCEPTION_NONE:
244
+ self.data_manager.record(backward_output_list)
245
+
246
+ self.data_manager.save_results(api_name_str)
247
+
@@ -1,9 +1,25 @@
1
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from msprobe.core.common.const import Const
3
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
17
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
- from msprobe.mindspore.common.log import logger
6
18
  from msprobe.core.common.utils import is_invalid_pattern
19
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
20
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
21
+ from msprobe.mindspore.common.log import logger
22
+
7
23
 
8
24
  class ApiInfo:
9
25
  def __init__(self, api_name):
@@ -66,11 +82,10 @@ class ApiInfo:
66
82
  err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
67
83
  logger.error_log_with_exp(err_msg,
68
84
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
69
- if not isinstance(compute_element_info, (list, dict)):
70
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
85
+ if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None):
86
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null"
71
87
  logger.error_log_with_exp(err_msg,
72
88
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
73
89
  kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
74
90
  for key_str, compute_element_info in kwargs_dict.items()}
75
91
  return kwargs_compute_element_dict
76
-
@@ -1,15 +1,27 @@
1
-
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
2
15
 
3
16
  import mindspore
4
17
  import torch
5
18
  from mindspore import ops
6
-
7
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
19
  from msprobe.core.common.const import Const, MsCompareConst
9
20
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
- from msprobe.mindspore.common.log import logger
11
- from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
21
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
12
22
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
23
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
24
+ from msprobe.mindspore.common.log import logger
13
25
 
14
26
 
15
27
  class ApiInputAggregation:
@@ -24,11 +36,23 @@ class ApiInputAggregation:
24
36
  self.kwargs = kwargs
25
37
  self.gradient_inputs = gradient_inputs
26
38
 
39
+
27
40
  api_parent_module_mapping = {
28
41
  (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
29
42
  (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
30
43
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
31
- (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
44
+ (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
45
+ (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
46
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
47
+ }
48
+
49
+ api_parent_module_str_mapping = {
50
+ (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
51
+ (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
52
+ (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
53
+ (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
54
+ (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
55
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
32
56
  }
33
57
 
34
58
 
@@ -60,7 +84,7 @@ class ApiRunner:
60
84
  api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
61
85
 
62
86
  Return:
63
- api_type_str: str, Union["MintFunctional", "Mint"]
87
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
64
88
  api_sub_name: str, e.g. "relu"
65
89
  '''
66
90
  api_name_list = api_name_str.split(Const.SEP)
@@ -68,8 +92,8 @@ class ApiRunner:
68
92
  err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
69
93
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
70
94
  api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
71
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
72
- err_msg = f"ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
95
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
96
+ err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
73
97
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
74
98
 
75
99
  return api_type_str, api_sub_name
@@ -78,7 +102,7 @@ class ApiRunner:
78
102
  def get_api_instance(api_type_str, api_sub_name, api_platform):
79
103
  '''
80
104
  Args:
81
- api_type_str: str, Union["MintFunctional", "Mint"]
105
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
82
106
  api_sub_name: str, e.g. "relu"
83
107
  api_platform: str: Union["mindpore", "torch"]
84
108
 
@@ -92,9 +116,8 @@ class ApiRunner:
92
116
  '''
93
117
 
94
118
  api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
95
- module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
96
- submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
97
- full_api_name = module_str + submodule_str + api_sub_name
119
+ api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
120
+ full_api_name = api_parent_module_str + Const.SEP + api_sub_name
98
121
  if not hasattr(api_parent_module, api_sub_name):
99
122
  err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
100
123
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
@@ -115,7 +138,7 @@ class ApiRunner:
115
138
  gradient_inputs = api_input_aggregation.gradient_inputs
116
139
 
117
140
  if forward_or_backward == Const.FORWARD:
118
- forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
141
+ forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
119
142
  forward_result_tuple = convert_to_tuple(forward_result)
120
143
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
121
144
  else:
@@ -127,18 +150,20 @@ class ApiRunner:
127
150
  if api_platform == Const.MS_FRAMEWORK:
128
151
  if len(gradient_inputs) == 1:
129
152
  gradient_inputs = gradient_inputs[0]
153
+
130
154
  def api_with_kwargs(*forward_inputs):
131
155
  return api_instance(*forward_inputs, **kwargs)
156
+
132
157
  grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
133
- backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
158
+ backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
134
159
  backward_result_tuple = convert_to_tuple(backward_result)
135
160
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
136
161
  else:
137
- #set requires_grad
162
+ # set requires_grad
138
163
  requires_grad_index = []
139
164
  for index, tensor in enumerate(inputs):
140
165
  if isinstance(tensor, torch.Tensor) and \
141
- torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
166
+ torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
142
167
  setattr(tensor, "requires_grad", True)
143
168
  requires_grad_index.append(index)
144
169
  forward_results = api_instance(*inputs, **kwargs)
@@ -153,4 +178,4 @@ class ApiRunner:
153
178
  return res_compute_element_list
154
179
 
155
180
 
156
- api_runner = ApiRunner()
181
+ api_runner = ApiRunner()
@@ -1,12 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
 
3
18
  import mindspore
4
- import torch
5
19
  import numpy as np
6
-
20
+ import torch
21
+ from msprobe.core.common.const import CompareConst, MsCompareConst
7
22
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
8
23
  from msprobe.mindspore.common.log import logger
9
- from msprobe.core.common.const import CompareConst, MsCompareConst
24
+
10
25
 
11
26
  class CompareResult:
12
27
  def __init__(self, compare_value, pass_status, err_msg):
@@ -28,7 +43,7 @@ class BaseCompareAlgorithm(ABC):
28
43
  CompareConst.MAX_ABS_ERR: {
29
44
  CompareConst.PASS: "",
30
45
  CompareConst.ERROR: "max absolute difference is greater than " \
31
- f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
46
+ f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
32
47
  CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
33
48
  },
34
49
  CompareConst.MAX_RELATIVE_ERR: {
@@ -68,7 +83,7 @@ class BaseCompareAlgorithm(ABC):
68
83
  ndarray = tensor.to(torch.float64, copy=True).numpy()
69
84
  else:
70
85
  err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
71
- "input is not mindspore.Tensor or torch.Tensor"
86
+ "input is not mindspore.Tensor or torch.Tensor"
72
87
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
73
88
  return ndarray
74
89
 
@@ -189,9 +204,8 @@ class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
189
204
  return CompareConst.ERROR
190
205
 
191
206
 
192
-
193
207
  compare_algorithms = {
194
208
  CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
195
209
  CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
196
210
  CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
197
- }
211
+ }
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ # list of api that can be checked
17
+
18
+ tensor:
19
+ - add_
20
+ - add
21
+ - addmm_
22
+ - all
23
+ - allclose
24
+ - any
25
+ - bool
26
+ - byte
27
+ - ceil
28
+ - clamp
29
+ - contiguous
30
+ - copy_
31
+ - cos
32
+ - clone
33
+ - cumprod
34
+ - expand_as
35
+ - flatten
36
+ - float
37
+ - half
38
+ - int
39
+ - is_contiguous
40
+ - isnan
41
+ - item
42
+ - log
43
+ - log2
44
+ - long
45
+ - masked_fill
46
+ - max
47
+ - mean
48
+ - min
49
+ - numel
50
+ - numpy
51
+ - repeat
52
+ - repeat_interleave
53
+ - reshape
54
+ - round
55
+ - select
56
+ - sin
57
+ - size
58
+ - split
59
+ - sqrt
60
+ - square
61
+ - sub
62
+ - swapaxes
63
+ - to
64
+ - t
65
+ - tolist
66
+ - topk
67
+ - transpose
68
+ - trunc
69
+ - type
70
+ - unsqueeze
71
+ - view
72
+ - view_as
73
+ - fill_
74
+ - floor_
75
+ - clamp_
76
+ - type_as
77
+ - zero_