mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,6 +1,68 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+
20
+ from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
21
+ from msprobe.core.common.utils import Const, MsprobeBaseException
22
+
23
class UniqueDeviceAction(argparse.Action):
    """argparse action that validates a list of device ids.

    The list is rejected through ``parser.error`` when its ids are not pairwise
    distinct, or when any id lies outside [0, 4095]. Otherwise the list is
    stored unchanged on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if len(set(values)) != len(values):
            parser.error("device id must be unique")
        out_of_range = (dev for dev in values if not 0 <= dev <= 4095)
        for bad_id in out_of_range:
            parser.error(f"the argument 'device_id' must be in range [0, 4095], but got {bad_id}")
        setattr(namespace, self.dest, values)
32
+
33
+
1
34
def add_api_accuracy_checker_argument(parser):
    """Register the single-process API accuracy checker options on ``parser``.

    Options:
        -api_info / --api_info_file: required path to the json file generated
            by the api param tool.
        -o / --out_path: directory for the result csv files (default "./").
        -csv_path / --result_csv_path: optional result csv from a previous run,
            used to resume checking (default "", i.e. start a fresh run).
    """
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
                        help="<Required> The api param tool result file: generate from api param tool, "
                             "a json file.")
    parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
                        help="<optional> The ut task result out path.")
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        required=False,
                        help="<optional> The existing result csv file to resume checking from.")
42
+
43
def multi_add_api_accuracy_checker_argument(parser):
    """Register the multi-device API accuracy checker options on ``parser``.

    Registers every option of ``add_api_accuracy_checker_argument`` plus
    ``-d/--device``: one or more device ids (default ``[0]``) that
    ``UniqueDeviceAction`` checks for uniqueness and the range [0, 4095].
    """
    # The common options are identical to the single-process checker's.
    add_api_accuracy_checker_argument(parser)
    # Options below only apply to the multi-threaded checker.
    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                        help="<optional> set device id to run ut, must be unique and in range [0, 4095]",
                        default=[0], required=False, action=UniqueDeviceAction)
55
+
56
+
57
def check_args(args):
    """Normalise and validate the parsed command-line arguments in place.

    - ``api_info_file`` is made absolute and checked with
      ``check_file_or_directory_path``.
    - ``out_path`` falls back to "./" when empty, is made absolute and is
      created with ``create_directory``.
    - ``result_csv_path``, only when non-empty, is made absolute and checked
      with ``check_file_or_directory_path``.
    """
    info_file = os.path.abspath(args.api_info_file)
    args.api_info_file = info_file
    check_file_or_directory_path(info_file)

    out_dir = "./" if args.out_path == "" else args.out_path
    args.out_path = os.path.abspath(out_dir)
    create_directory(args.out_path)

    if not args.result_csv_path:
        return
    args.result_csv_path = os.path.abspath(args.result_csv_path)
    check_file_or_directory_path(args.result_csv_path)
@@ -1,21 +1,37 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  import mindspore
4
- import torch
5
19
  import numpy as np
6
-
7
- from msprobe.mindspore.common.log import logger
20
+ import torch
21
+ from mindspore._c_expression import typing
22
+ from msprobe.core.common.const import Const
8
23
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
24
  from msprobe.core.common.file_utils import load_npy
10
- from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
25
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
11
26
  ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
27
  dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
28
  dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
29
  DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
- MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
- int_dtype_str_list)
17
- from msprobe.core.common.const import Const
30
+ MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
31
+ SLICE_TYPE_STR, TORCH_DTYPE_TYPE_STR,
32
+ float_dtype_str_list, int_dtype_str_list)
18
33
  from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
34
+ from msprobe.mindspore.common.log import logger
19
35
 
20
36
 
21
37
  class MstensorMetaData:
@@ -26,6 +42,12 @@ class MstensorMetaData:
26
42
  self.minimum = minimum
27
43
  self.shape = shape
28
44
 
45
+
46
class DtypeMetaData:
    """Framework-neutral placeholder for a dtype, identified by its string name.

    ComputeElement keeps this instead of a concrete dtype object and resolves
    the name through ``dtype_str_to_ms_dtype`` or ``dtype_str_to_torch_dtype``
    depending on the target platform.
    """

    def __init__(self, dtype_str) -> None:
        # The dtype's string name, used as the lookup key in the dtype maps.
        self.dtype_str = dtype_str
49
+
50
+
29
51
  class ComputeElement:
30
52
  def __init__(self, compute_element_info=None, parameter=None):
31
53
  self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
@@ -118,6 +140,11 @@ class ComputeElement:
118
140
  for compute_element in self.parameter])
119
141
  elif isinstance(self.parameter, self.supported_parameter_type):
120
142
  parameter_tmp = self.parameter
143
+ elif isinstance(self.parameter, DtypeMetaData):
144
+ if tensor_platform == Const.MS_FRAMEWORK:
145
+ parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
146
+ else:
147
+ parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
121
148
  elif isinstance(self.parameter, MstensorMetaData):
122
149
  mstensor_meta_data = self.parameter
123
150
  ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
@@ -130,13 +157,13 @@ class ComputeElement:
130
157
  parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
131
158
  else:
132
159
  err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
133
- "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
160
+ "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
134
161
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
135
162
 
136
163
  # if necessary, do transfer
137
164
  if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
138
165
  parameter = self.transfer_to_torch_tensor(parameter_tmp)
139
- elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
166
+ elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
140
167
  parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
141
168
  else:
142
169
  parameter = parameter_tmp
@@ -183,34 +210,38 @@ class ComputeElement:
183
210
  else:
184
211
  type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
185
212
  accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
186
-
213
+ self.shape = tuple()
214
+ self.dtype_str = type_str
187
215
  if type_str == MINDSPORE_TENSOR_TYPE_STR:
188
216
  self._init_from_mstensor_compute_element_info(compute_element_info)
189
- else: # type_str in ("slice", "int", "float", "bool")
217
+ else:
190
218
  value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
191
- self.shape = tuple()
192
- self.dtype_str = type_str
193
- self.parameter = slice(*tuple(value)) if type_str == "slice" else value
219
+ if type_str == MINDSPORE_DTYPE_TYPE_STR:
220
+ self.parameter = DtypeMetaData(value)
221
+ elif type_str == SLICE_TYPE_STR:
222
+ self.parameter = slice(*tuple(value))
223
+ else: # type_str in ("str", "int", "float", "bool")
224
+ self.parameter = value
194
225
 
195
226
  def _init_from_mstensor_compute_element_info(self, compute_element_info):
196
227
  '''
197
228
  do not load real tensor, only record meta data
198
229
  '''
199
230
  dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
200
- accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
231
+ accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
201
232
  shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
202
- accepted_type=(list,))
233
+ accepted_type=(list,))
203
234
  if global_context.get_is_constructed():
204
235
  maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
205
- accepted_type=(int, float))
236
+ accepted_type=(int, float))
206
237
  minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
207
- accepted_type=(int, float))
238
+ accepted_type=(int, float))
208
239
 
209
240
  npy_path = None
210
241
  else:
211
242
  maximum, minimum = None, None
212
243
  data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
213
- "data_name field in api_info.json", accepted_type=(str,))
244
+ "data_name field in api_info.json", accepted_type=(str,))
214
245
  npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
215
246
  mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
216
247
  self.parameter = mstensor_meta_data
@@ -219,9 +250,10 @@ class ComputeElement:
219
250
 
220
251
  def _init_with_parameter(self, parameter):
221
252
  self.parameter = parameter
253
+ self.shape = tuple()
222
254
  if not isinstance(parameter, self.supported_parameter_type):
223
255
  err_msg = "ComputeElement._init_with_parameter failed: " \
224
- "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
256
+ "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
225
257
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
226
258
  if isinstance(parameter, mindspore.Tensor):
227
259
  self.shape = tuple(parameter.shape)
@@ -229,11 +261,14 @@ class ComputeElement:
229
261
  elif isinstance(parameter, torch.Tensor):
230
262
  self.shape = tuple(parameter.shape)
231
263
  self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
264
+ elif isinstance(parameter, typing.Type):
265
+ self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
266
+ self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
267
+ elif isinstance(parameter, torch.dtype):
268
+ self.dtype_str = TORCH_DTYPE_TYPE_STR
269
+ self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
232
270
  elif isinstance(parameter, tuple):
233
- self.shape = tuple()
234
271
  self.dtype_str = TUPLE_TYPE_STR
235
272
  self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
236
273
  else:
237
- self.shape = tuple()
238
- self.dtype_str = \
239
- TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
274
+ self.dtype_str = type_to_api_info_type_str.get(type(parameter))
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import csv
18
+
19
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
21
+ from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
22
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
23
+ from msprobe.core.common.file_utils import check_file_or_directory_path
24
+ from msprobe.mindspore.common.log import logger
25
+
26
+
27
class ResultCsvEntry:
    """Mutable accumulator for one API's row in the summary (result) csv.

    Pass statuses start as ``None`` (nothing recorded yet for that pass). The
    per-pass error messages start empty, and ``overall_err_msg`` starts as ``None``.
    """

    def __init__(self) -> None:
        self.forward_pass_status, self.backward_pass_status = None, None
        self.forward_err_msg, self.backward_err_msg = "", ""
        self.overall_err_msg = None
33
+ self.overall_err_msg = None
34
+
35
+
36
def write_csv_header(csv_path, header_func):
    """Append the header row returned by ``header_func()`` to ``csv_path``.

    This function does not check whether the file already has a header; the
    caller is responsible for calling it only on the first write.
    """
    header_row = header_func()
    logger.debug(f"Writing CSV header: {header_row}")
    write_csv([header_row], csv_path, mode="a+")
41
+
42
+
43
def get_result_csv_header():
    """Return the ordered column names of the summary (result) csv."""
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
51
+
52
+
53
def get_detail_csv_header():
    """Return the ordered column names of the detail csv.

    The columns are, in order:
    - basic info: api name, bench dtype, tested dtype, shape;
    - one column per entry of ``compare_algorithms``, in key order;
    - pass status and message.
    """
    header = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
    ]
    header.extend(compare_algorithms.keys())
    header += [MsCompareConst.DETAIL_CSV_PASS_STATUS, MsCompareConst.DETAIL_CSV_MESSAGE]
    return header
67
+
68
+
69
def check_csv_header(headers, required_constants, csv_path):
    """Verify that every required column name appears in a csv header row.

    Matching is by substring (``required in header``), so a header cell that
    starts with a UTF-8 byte-order mark still matches.

    Args:
        headers: the csv header row (a sequence of cell strings).
        required_constants: column names that must all be present.
        csv_path: path of the csv; used only in the error message.

    Returns:
        True when all required columns are present, so callers may use the
        call directly as a condition.

    Raises:
        MsprobeBaseException: with ``MISSING_HEADER_ERROR``, listing the
            missing column names.
    """
    missing_constants = [
        required for required in required_constants
        if not any(required in header for header in headers)
    ]
    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    return True
78
+
79
+
80
class DataManager:
    """Buffer per-API comparison results and append them to two csv files.

    The two files are:
    - the *detail* csv: one row per compared output (basic info, one value per
      compare algorithm, status, message);
    - the *result* csv: one summary row per API with forward/backward status.

    When ``result_csv_path`` is given, the run resumes a previous one. Rows are
    appended to that file and to its sibling "details" file, and the API names
    already present are loaded so callers can skip them via ``is_unique_api``.
    """

    def __init__(self, csv_dir, result_csv_path):
        self.results = {}
        self.is_first_write = True  # headers still need to be written
        self.csv_dir = csv_dir
        self.api_names_set = set()  # API names already seen (this run or the resumed csv)
        if result_csv_path:
            # Resume mode: keep appending to the csv files of a previous run.
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # Fresh run: timestamped result file. The detail file has the same
            # name with "result" replaced by "details".
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    def initialize_api_names_set(self, result_csv_path):
        """Load the API names already recorded in an existing result csv.

        Raises:
            MsprobeBaseException: if the header row lacks a required column
                (raised by ``check_csv_header``).
        """
        csv_data = read_csv(result_csv_path, as_pd=False)
        headers = csv_data[0] if csv_data else []  # empty file -> no header row

        # Called for its validation only (it raises on a missing column). The
        # names must be loaded unconditionally afterwards, so do not gate on
        # its return value.
        check_csv_header(headers, get_result_csv_header(), result_csv_path)

        # The header row may start with a byte-order mark, so match by substring.
        api_name_index = next(
            (i for i, header in enumerate(headers) if MsCompareConst.DETAIL_CSV_API_NAME in header),
            None,
        )
        if api_name_index is None:
            logger.warning(f"{result_csv_path} No column contains 'API Name'.")
            return

        for row in csv_data[1:]:  # skip the header row
            if row and len(row) > api_name_index and row[api_name_index]:
                self.api_names_set.add(row[api_name_index])

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return True and remember ``api_name`` if unseen; return False if already seen."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Point the output paths at the csv files of a previous run.

        The detail path is derived from ``result_csv_path`` by replacing
        "result" with "details" in its basename. Headers are not rewritten.
        """
        last_dir = os.path.dirname(result_csv_path)

        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Write buffered results to the detail and result csv files, then clear them.

        Headers are written before the first batch only when ``is_first_write``
        is set, which is the case for a fresh run. ``api_name_str`` is used only
        in log messages.
        """
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False  # never write the headers twice

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # Start the next batch from an empty buffer.
        self.clear_results()

    def record(self, output_list):
        """Buffer compare outputs grouped by ``(api_real_name, forward_or_backward)``.

        Each output is a 4-tuple ``(api_real_name, forward_or_backward,
        basic_info, compare_result_dict)``. A ``None`` list is ignored.
        """
        if output_list is None:
            return
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def clear_results(self):
        """Empty the buffered results."""
        logger.debug("Clearing self.results data.")
        self.results.clear()

    def to_detail_csv(self, csv_path):
        """Append one detail row per buffered output to the csv at ``csv_path``."""
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        # Same algorithm column order as get_detail_csv_header().
        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                # NOTE(review): assumes compare_result_dict has an entry for every
                # algorithm; a missing one makes .get() return None and raises here.
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        """Append one summary row per API to the result csv at ``csv_path``.

        Pass status rules:
        - A pass (forward or backward) is PASS only if every recorded output of
          it passed. Otherwise it is ERROR, and the outputs' error messages are
          concatenated.
        - A pass with no recorded outputs keeps status ``None``.
        """
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for (api_real_name, forward_or_backward), results in self.results.items():
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                overall_err_msg += basic_info.err_msg

            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = pass_status
                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = pass_status
                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
                                     entry.backward_pass_status == CompareConst.PASS) else \
                entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # Headers must not be added again by later writes.
        self.is_first_write = False
@@ -1,9 +1,33 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
17
 
18
+ from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
19
+
20
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
21
+
3
22
 
4
23
def api_checker_main(args):
    """Entry point of the single-process MindSpore API accuracy check.

    ``args`` is first normalised and validated by ``check_args``. An
    ``ApiAccuracyChecker`` is then built from it, parses ``args.api_info_file``
    and runs the comparison.
    """
    check_args(args)
    checker = ApiAccuracyChecker(args)
    api_info_path = args.api_info_file
    checker.parse(api_info_path)
    checker.run_and_compare()
28
+
29
def mul_api_checker_main(args):
    """Entry point of the multi-device MindSpore API accuracy check.

    This mirrors ``api_checker_main`` but drives a ``MultiApiAccuracyChecker``.
    """
    check_args(args)
    multi_checker = MultiApiAccuracyChecker(args)
    api_info_path = args.api_info_file
    multi_checker.parse(api_info_path)
    multi_checker.run_and_compare()
8
- api_accuracy_checker.to_detail_csv(args.out_path)
9
- api_accuracy_checker.to_result_csv(args.out_path)