mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import csv
18
+
19
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
21
+ from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
22
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
23
+ from msprobe.core.common.file_utils import check_file_or_directory_path
24
+ from msprobe.mindspore.common.log import logger
25
+
26
+
27
class ResultCsvEntry:
    """Accumulated summary state for one API: forward/backward pass status and error messages."""

    def __init__(self) -> None:
        # Statuses stay None until the matching direction is recorded by DataManager.to_result_csv.
        self.forward_pass_status = self.backward_pass_status = None
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
34
+
35
+
36
def write_csv_header(csv_path, header_func):
    """Append the header row produced by ``header_func()`` to ``csv_path``.

    The row is always appended (mode "a+"); callers are responsible for invoking this
    only when a header is actually needed, e.g. on the first write.
    """
    header_row = header_func()
    logger.debug(f"Writing CSV header: {header_row}")
    rows = [header_row]
    write_csv(rows, csv_path, mode="a+")
41
+
42
+
43
def get_result_csv_header():
    """Return the ordered column names of the summary (result) CSV file."""
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
51
+
52
+
53
def get_detail_csv_header():
    """Return the ordered column names of the detail CSV file.

    Layout: basic API info, then one column per registered compare algorithm
    (in ``compare_algorithms`` key order), then pass status and message.
    """
    return [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
        *compare_algorithms.keys(),
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
67
+
68
+
69
def check_csv_header(headers, required_constants, csv_path):
    """Verify that a CSV header row contains every required column name.

    A required name counts as present when it is a substring of some header cell.
    This tolerates extra characters such as a byte-order mark on the first cell.

    Args:
        headers: header cells read from the CSV file.
        required_constants: column names that must all be present.
        csv_path: path of the CSV file; used only in the error message.

    Returns:
        True when all required columns are present, so the call can be used as a condition.

    Raises:
        MsprobeBaseException: with code MISSING_HEADER_ERROR, listing the missing columns.
    """
    missing_constants = [const for const in required_constants if not any(const in header for header in headers)]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    # Explicit success value: callers test the result with `if check_csv_header(...)`.
    return True
78
+
79
+
80
class DataManager:
    """Collect per-API comparison results and append them to result/detail CSV files.

    When ``result_csv_path`` is given, the manager resumes from that file. It appends to it
    (and to the sibling "details" file) and preloads the API names already present, so
    callers can skip them via :meth:`is_unique_api`.
    """

    def __init__(self, csv_dir, result_csv_path):
        # Results recorded since the last save, keyed by (api_real_name, forward_or_backward).
        self.results = {}
        # True until CSV headers have been written; resumed files already have headers.
        self.is_first_write = True
        self.csv_dir = csv_dir
        # API names already present in the result CSV or already processed in this run.
        self.api_names_set = set()
        if result_csv_path:
            # Resume mode: reuse the previous run's output files.
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # Fresh run: create timestamped output paths under csv_dir.
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    @staticmethod
    def _find_column_index(headers, column_name):
        """Return the index of the first header cell containing ``column_name``, or None.

        Containment (not equality) is used because the CSV header row may start with a
        byte-order mark.
        """
        return next((index for index, header in enumerate(headers) if column_name in header), None)

    def initialize_api_names_set(self, result_csv_path):
        """Load the API names already listed in an existing result CSV into ``api_names_set``.

        Raises:
            MsprobeBaseException: raised by ``check_csv_header`` when the header row lacks any
                column from ``get_result_csv_header()``, including when the file is empty.
        """
        csv_data = read_csv(result_csv_path, as_pd=False)
        # An empty file yields no header row; the validation below then rejects it.
        headers = csv_data[0] if csv_data else []

        # Validation only: check_csv_header raises on a missing column. The names are loaded
        # unconditionally afterwards rather than gating on its return value.
        check_csv_header(headers, get_result_csv_header(), result_csv_path)

        api_name_index = self._find_column_index(headers, MsCompareConst.DETAIL_CSV_API_NAME)
        if api_name_index is None:
            logger.warning(f"{result_csv_path} No column contains 'API Name'.")
            return

        # Skip the header row; ignore short rows and empty API-name cells.
        for row in csv_data[1:]:
            if len(row) > api_name_index and row[api_name_index]:
                self.api_names_set.add(row[api_name_index])

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return False if ``api_name`` was already seen; otherwise remember it and return True."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Point the output paths at a previous run's result CSV and its sibling details CSV."""
        last_dir = os.path.dirname(result_csv_path)

        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        # The resumed files already contain headers.
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Write the recorded results to the detail and result CSVs, then clear them.

        Headers are written once, before the first batch of rows.

        Args:
            api_name_str: API name used only in debug log messages.
        """
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # Reset so the next call only writes newly recorded results.
        self.clear_results()

    def record(self, output_list):
        """Append comparison outputs to ``self.results``.

        Args:
            output_list: iterable of 4-tuples
                (api_real_name, forward_or_backward, basic_info, compare_result_dict);
                None is ignored.
        """
        if output_list is None:
            return
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def clear_results(self):
        """Discard all recorded results."""
        logger.debug("Clearing self.results data.")
        self.results.clear()

    def to_detail_csv(self, csv_path):
        """Append one detail row per recorded comparison to ``csv_path``.

        Compare-value columns follow the key order of ``compare_algorithms``, matching
        ``get_detail_csv_header()``.
        """
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        """Append one summary row per API (forward status, backward status, message) to ``csv_path``.

        A direction is marked CompareConst.ERROR if any of its recorded comparisons did not
        pass. A direction with no recorded results keeps a None status.
        """
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for (api_real_name, forward_or_backward), results in self.results.items():
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                overall_err_msg += basic_info.err_msg

            # Only keep error text for directions that failed.
            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = pass_status
                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = pass_status
                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
                                     entry.backward_pass_status == CompareConst.PASS) else \
                entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # Rows have been written, so headers must not be added again later.
        self.is_first_write = False
@@ -1,9 +1,33 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
17
 
18
+ from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
19
+
20
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
21
+
3
22
 
4
23
def _run_checker(checker_cls, args):
    """Validate ``args``, build ``checker_cls(args)``, parse the API info file and run the comparison."""
    check_args(args)
    checker = checker_cls(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()


def api_checker_main(args):
    """Entry point: run ApiAccuracyChecker with the parsed command-line ``args``."""
    _run_checker(ApiAccuracyChecker, args)


def mul_api_checker_main(args):
    """Entry point: run MultiApiAccuracyChecker with the parsed command-line ``args``."""
    _run_checker(MultiApiAccuracyChecker, args)
@@ -0,0 +1,206 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # 标准库导入
17
+ import multiprocessing
18
+ from multiprocessing import Manager
19
+ import os
20
+ import signal
21
+ import sys
22
+ import time
23
+
24
+ # 第三方库导入
25
+ from mindspore import context
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+
29
+ # 本地应用/库特定导入
30
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
31
+ from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
32
+ from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
33
+ from msprobe.mindspore.common.log import logger
34
+
35
+
36
class MultiApiAccuracyChecker(ApiAccuracyChecker):
    """API accuracy checker that spreads the APIs to check over several devices.

    ``run_and_compare`` assigns the entries of ``self.api_infos`` (presumably filled by the
    inherited ``parse``) round-robin to the device IDs in ``args.device_id``. It then starts one
    ``multiprocessing.Process`` per device. Each worker runs the forward/backward checks for
    its share and writes results through a shared ``MultiDataManager``. The parent process
    drives the progress bar and joins the workers.
    """

    # Seconds between worker liveness checks while waiting for progress updates.
    _PROGRESS_POLL_INTERVAL = 1.0

    def __init__(self, args):
        # NOTE(review): ApiAccuracyChecker.__init__ is not called; confirm that every parent
        # attribute used by the inherited helpers is initialised here.
        self.api_infos = {}

        # A Manager-backed Value so every worker process sees the same
        # "CSV headers not written yet" flag.
        self.manager = Manager()
        self.is_first_write = self.manager.Value('b', True)

        self.multi_data_manager = MultiDataManager(args.out_path, args.result_csv_path, self.is_first_write)

        self.args = args

        # Device ID handled by the current process; only used to prefix log messages.
        self.current_device_id = None

    def process_on_device(self, device_id, api_infos, progress_queue):
        """Check a subset of APIs on one device (this is the worker-process target).

        Args:
            device_id (int): Device ID passed to ``context.set_context``.
            api_infos (list): ``(api_name_str, api_info)`` tuples assigned to this device.
            progress_queue (multiprocessing.Queue): Receives ``1`` for every API handled,
                including skipped duplicates, so the parent can advance its progress bar.
        """
        self.current_device_id = device_id
        context.set_context(device_id=device_id)

        for api_name_str, api_info in api_infos:
            logger.debug(f"Processing API: {api_name_str}, Device: {device_id}")

            if not self.multi_data_manager.is_unique_api(api_name_str):
                logger.debug(f"API {api_name_str} is not unique, skipping.")
                progress_queue.put(1)
                continue

            forward_output_list = self.process_forward(api_name_str, api_info)
            if forward_output_list is not Const.EXCEPTION_NONE:
                self.multi_data_manager.record(forward_output_list)

            backward_output_list = self.process_backward(api_name_str, api_info)
            if backward_output_list is not Const.EXCEPTION_NONE:
                self.multi_data_manager.record(backward_output_list)

            self.multi_data_manager.save_results(api_name_str)
            progress_queue.put(1)

    def run_and_compare(self):
        """Check every API in ``self.api_infos`` using one worker process per device.

        APIs are assigned round-robin: the i-th API goes to ``args.device_id[i % device_count]``.

        Raises:
            ValueError: If ``args.device_id`` is empty.
        """
        device_ids = self.args.device_id
        if not device_ids:
            raise ValueError("No device ID was given for the multi-device API accuracy check.")

        api_items = list(self.api_infos.items())
        device_count = len(device_ids)
        # A strided slice reproduces the modulo split: item idx lands in bucket idx % device_count.
        tasks_per_device = [api_items[start::device_count] for start in range(device_count)]

        progress_queue = multiprocessing.Queue()
        total_tasks = len(api_items)

        with tqdm(total=total_tasks, desc="Total Progress", ncols=100) as pbar:
            processes = []
            for device_id, device_tasks in zip(device_ids, tasks_per_device):
                process = multiprocessing.Process(target=self.process_on_device,
                                                  args=(device_id, device_tasks, progress_queue))
                processes.append(process)
                process.start()

            self._wait_for_progress(processes, progress_queue, total_tasks, pbar)

            # Give every worker a chance to finish; forcibly stop the ones still running.
            for process in processes:
                process.join(timeout=Const.PROGRESS_TIMEOUT)
                if process.is_alive():
                    logger.error(f"Process {process.pid} did not terminate. Forcing termination.")
                    process.terminate()

    def _wait_for_progress(self, processes, progress_queue, total_tasks, pbar):
        """Feed worker progress updates into ``pbar`` until there is nothing left to wait for.

        The method returns in any of these cases:
        - ``total_tasks`` updates have arrived;
        - every worker has exited (after draining what they queued);
        - no update has arrived for ``Const.PROGRESS_TIMEOUT`` seconds.

        Workers that exit with a non-zero code are logged once each. Crashed workers are
        detected by liveness rather than by adjusting ``total_tasks``, because the parent
        cannot tell how many tasks a crashed worker had already reported.
        """
        import queue

        completed_tasks = 0
        failed_processes = set()
        last_update = time.monotonic()
        while completed_tasks < total_tasks:
            self._log_failed_processes(processes, failed_processes)
            try:
                completed_tasks += progress_queue.get(timeout=self._PROGRESS_POLL_INTERVAL)
            except queue.Empty:
                if not any(process.is_alive() for process in processes):
                    # All workers are gone: collect what they flushed before exiting, then stop.
                    completed_tasks += self._drain_progress(progress_queue, pbar)
                    break
                if time.monotonic() - last_update >= Const.PROGRESS_TIMEOUT:
                    logger.error("Timeout while waiting for progress updates. Skipping remaining tasks.")
                    break
                continue
            pbar.update(1)
            last_update = time.monotonic()

        self._log_failed_processes(processes, failed_processes)
        if completed_tasks < total_tasks:
            logger.error(f"Only {completed_tasks} of {total_tasks} API tasks reported completion.")

    @staticmethod
    def _drain_progress(progress_queue, pbar):
        """Consume progress updates already in the queue without blocking and return their sum."""
        import queue

        drained = 0
        while True:
            try:
                drained += progress_queue.get_nowait()
            except queue.Empty:
                return drained
            pbar.update(1)

    @staticmethod
    def _log_failed_processes(processes, already_reported):
        """Log each worker that has exited with a non-zero code, at most once per worker."""
        for process in processes:
            if process in already_reported or process.is_alive():
                continue
            if process.exitcode != 0:
                logger.error(f"Process {process.pid} exited with code {process.exitcode}.")
                already_reported.add(process)

    def process_forward(self, api_name_str, api_info):
        """Run the forward check for one API, logging the device ID on failures.

        Parameters:
            api_name_str (str): The name of the API.
            api_info (object): The API information object.

        Returns:
            The forward output list, or ``Const.EXCEPTION_NONE`` when forward info is missing
            or any step raises.
        """
        return self._process_direction(api_name_str, api_info, Const.FORWARD, "forward",
                                       api_info.check_forward_info)

    def process_backward(self, api_name_str, api_info):
        """Run the backward check for one API, logging the device ID on failures.

        Parameters:
            api_name_str (str): The name of the API.
            api_info (object): The API information object.

        Returns:
            The backward output list, or ``Const.EXCEPTION_NONE`` when backward info is missing
            or any step raises.
        """
        return self._process_direction(api_name_str, api_info, Const.BACKWARD, "backward",
                                       api_info.check_backward_info)

    def _process_direction(self, api_name_str, api_info, direction, label, has_info):
        """Shared forward/backward flow: check info, build inputs, run and compare.

        Args:
            api_name_str (str): The name of the API.
            api_info (object): The API information object.
            direction: ``Const.FORWARD`` or ``Const.BACKWARD``, passed to the inherited helpers.
            label (str): ``"forward"`` or ``"backward"``, used in log messages.
            has_info (callable): Zero-argument check telling whether ``api_info`` has data
                for this direction.

        Returns:
            The helper's output list, or ``Const.EXCEPTION_NONE`` on any failure. This is a
            single sentinel so callers can test it with ``is not Const.EXCEPTION_NONE``.
        """
        prefix = f"[Device {self.current_device_id}]"
        if not has_info():
            logger.debug(f"{prefix} API: {api_name_str} lacks {label} information, skipping {label} check.")
            return Const.EXCEPTION_NONE

        # Broad catches are deliberate: one failing API must not abort the whole worker.
        try:
            inputs_aggregation = self.prepare_api_input_aggregation(api_info, direction)
        except Exception as e:
            logger.warning(
                f"{prefix} Exception occurred while getting {label} API inputs for {api_name_str}. "
                f"Skipping {label} check. Detailed exception information: {e}.")
            return Const.EXCEPTION_NONE

        try:
            return self.run_and_compare_helper(api_info, api_name_str, inputs_aggregation, direction)
        except Exception as e:
            logger.warning(
                f"{prefix} Exception occurred while running and comparing {api_name_str} {label} API. "
                f"Detailed exception information: {e}.")
            return Const.EXCEPTION_NONE
@@ -0,0 +1,58 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import multiprocessing
18
+ import os
19
+
20
+ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager, ResultCsvEntry, write_csv_header, get_result_csv_header, get_detail_csv_header, check_csv_header
21
+ from msprobe.mindspore.common.log import logger
22
+
23
+
24
class MultiDataManager(DataManager):
    """DataManager variant whose CSV output is shared by several worker processes.

    Header writing is coordinated through ``shared_is_first_write`` (an object exposing a
    boolean ``.value``, shared between processes). Every save is serialised with a
    ``multiprocessing.Lock`` created in ``__init__``.
    """

    def __init__(self, csv_dir, result_csv_path, shared_is_first_write):
        super().__init__(csv_dir, result_csv_path)

        # Cross-process "CSV headers not written yet" flag, read and written via ``.value``.
        self.shared_is_first_write = shared_is_first_write
        # Process-level lock. It only coordinates processes that share this very object,
        # i.e. the instance must exist before the worker processes are started.
        self.lock = multiprocessing.Lock()

    def save_results(self, api_name_str):
        """Append the buffered results to the detail and result CSVs, then clear the buffer.

        The whole write runs under ``self.lock`` so rows from different processes do not
        interleave. The CSV headers are written once across all processes sharing the flag.

        Args:
            api_name_str (str): API name; only used in debug log messages.
        """
        with self.lock:
            if self.is_first_write and self.shared_is_first_write.value:
                self.shared_is_first_write.value = False
                logger.info("Writing CSV headers for the first time.")
                write_csv_header(self.detail_out_path, get_detail_csv_header)
                write_csv_header(self.result_out_path, get_result_csv_header)
            # Whether this process or another one wrote the headers, this process must not
            # emit them again. The inherited CSV writers key header handling on the local
            # flag, so clear it unconditionally.
            self.is_first_write = False

            # Write the detailed output and the result summary, then reset the buffer.
            self.to_detail_csv(self.detail_out_path)
            logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

            self.to_result_csv(self.result_out_path)
            logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

            self.clear_results()

    def clear_results(self):
        """Empty ``self.results`` so the next API starts from a clean buffer.

        Takes no lock itself; ``save_results`` calls it while already holding ``self.lock``.
        """
        logger.debug("Clearing results data.")
        self.results.clear()
@@ -1,7 +1,23 @@
1
- from mindspore.common import dtype as mstype
2
- import numpy as np
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
3
16
  import mindspore
17
+ import numpy as np
4
18
  import torch
19
+ from mindspore._c_expression import typing
20
+ from mindspore.common import dtype as mstype
5
21
 
6
22
  INT8 = "Int8"
7
23
  UINT8 = "UInt8"
@@ -18,7 +34,6 @@ BOOL = "Bool"
18
34
  BFLOAT16 = "BFloat16"
19
35
  INT4 = "Int4"
20
36
 
21
-
22
37
  dtype_str_to_ms_dtype = {
23
38
  INT8: mstype.int8,
24
39
  UINT8: mstype.uint8,
@@ -37,7 +52,6 @@ dtype_str_to_ms_dtype = {
37
52
  }
38
53
  ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()}
39
54
 
40
-
41
55
  dtype_str_to_np_dtype = {
42
56
  INT8: np.int8,
43
57
  UINT8: np.uint8,
@@ -75,6 +89,8 @@ FLOAT_TYPE_STR = "float"
75
89
  SLICE_TYPE_STR = "slice"
76
90
  TUPLE_TYPE_STR = "tuple"
77
91
  STR_TYPE_STR = "str"
92
+ MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
93
+ TORCH_DTYPE_TYPE_STR = "torch.dtype"
78
94
 
79
95
  api_info_type_str_to_type = {
80
96
  MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
@@ -83,6 +99,7 @@ api_info_type_str_to_type = {
83
99
  FLOAT_TYPE_STR: float,
84
100
  SLICE_TYPE_STR: slice,
85
101
  STR_TYPE_STR: str,
102
+ MINDSPORE_DTYPE_TYPE_STR: typing.Type,
86
103
  }
87
104
  type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}
88
105
 
@@ -111,4 +128,4 @@ uint_dtype_str_list = [
111
128
  UINT16,
112
129
  UINT32,
113
130
  UINT64,
114
- ]
131
+ ]