mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -0,0 +1,211 @@
1
+ # 动态图精度数据采集快速入门示例
2
+
3
+ 本示例将展示如何在 MindSpore 动态图模式下使用 msprobe 工具进行精度数据采集。
4
+
5
+ ## 1. 配置文件
6
+
7
+ 请在当前目录下创建一个名为 `config.json` 的配置文件,内容如下:
8
+
9
+ ```json
10
+ {
11
+ "task": "statistics",
12
+ "dump_path": "./output",
13
+ "rank": [],
14
+ "step": ["0-2"],
15
+ "level": "L1",
16
+ "statistics": {
17
+ "scope": [],
18
+ "list": [],
19
+ "data_mode": [
20
+ "all"
21
+ ],
22
+ "summary_mode": "statistics"
23
+ }
24
+ }
25
+
26
+ ```
27
+ 以上配置参数的详细介绍和使用方法请参见[《config.json 配置文件介绍》](../02.config_introduction.md)和[《config.json 配置示例》](../03.config_examples.md#3-mindspore-动态图场景)中的“MindSpore动态图场景”章节。
28
+
29
+ ## 2. 模型脚本
30
+
31
+ 在当前目录下创建一个 Python 脚本文件,例如 `alexnet_model.py`,将以下代码粘贴进去:
32
+
33
+ ```python
34
+ import os
35
+ import numpy as np
36
+ import mindspore as ms
37
+ from mindspore import nn, ops
38
+ from mindspore import context
39
+ from mindspore import Tensor
40
+ from msprobe.mindspore import PrecisionDebugger, seed_all
41
+
42
+ # 设置随机种子以确保结果可重现
43
+ seed_all(seed=1234, mode=False, rm_dropout=True)
44
+
45
+ # 配置文件路径
46
+ script_dir = os.path.dirname(os.path.abspath(__file__))
47
+ config_path = os.path.join(script_dir, 'config.json')
48
+
49
+ # 初始化精度调试器
50
+ debugger = PrecisionDebugger(config_path=config_path)
51
+
52
+ # 设置 MindSpore 设备上下文
53
+ context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0)
54
+
55
+ # 定义卷积层
56
+ def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True):
57
+ return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
58
+ has_bias=has_bias, pad_mode=pad_mode)
59
+
60
+ # 定义全连接层
61
+ def fc_layer(input_channels, out_channels, has_bias=True):
62
+ return nn.Dense(input_channels, out_channels, has_bias=has_bias)
63
+
64
+
65
+ class AlexNet(nn.Cell):
66
+ """
67
+ AlexNet 模型定义
68
+
69
+ 参数:
70
+ - num_classes: 分类数量
71
+ - channel: 输入通道数(图像的颜色通道数)
72
+ - phase: 模型运行阶段('train' 或 'test')
73
+ - include_top: 是否包含全连接层的顶部(最后的分类层)
74
+ """
75
+ def __init__(self, num_classes=10, channel=3, phase='train', include_top=True):
76
+ super(AlexNet, self).__init__()
77
+
78
+ # 卷积层
79
+ self.conv1 = conv_layer(channel, 64, 11, stride=4, pad_mode="same")
80
+ self.conv2 = conv_layer(64, 128, 5, pad_mode="same")
81
+ self.conv3 = conv_layer(128, 192, 3, pad_mode="same")
82
+ self.conv4 = conv_layer(192, 256, 3, pad_mode="same")
83
+ self.conv5 = conv_layer(256, 256, 3, pad_mode="same")
84
+
85
+ # 激活函数和池化层
86
+ self.relu = nn.ReLU()
87
+ self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
88
+
89
+ # 如果包括顶部(全连接层)
90
+ self.include_top = include_top
91
+ if self.include_top:
92
+ self.flatten = nn.Flatten()
93
+ self.fc1 = fc_layer(256 * 28 * 28, 4096)
94
+ self.fc2 = fc_layer(4096, 4096)
95
+ self.fc3 = fc_layer(4096, num_classes)
96
+
97
+ # 数学操作
98
+ self.add = ops.Add()
99
+ self.mul = ops.Mul()
100
+
101
+ def construct(self, x):
102
+ """定义前向传播过程"""
103
+
104
+ x = self.conv1(x)
105
+ x = self.add(x, 0.1) # 偏置加法
106
+ x = self.mul(x, 2.0) # 乘法操作
107
+ x = self.relu(x) # ReLU 激活函数
108
+ x = ops.celu(x)
109
+ x = x + 2
110
+
111
+ # 打印每层输出形状,调试时可使用
112
+ print(f"After Conv1: {x.shape}")
113
+
114
+ x = self.max_pool2d(x) # Max pooling 操作
115
+ print(f"After MaxPool: {x.shape}") # 打印池化后的形状
116
+
117
+ x = self.conv2(x)
118
+ x = self.relu(x)
119
+
120
+ x = self.conv3(x)
121
+ x = self.relu(x)
122
+
123
+ x = self.conv4(x)
124
+ x = self.relu(x)
125
+
126
+ x = self.conv5(x)
127
+ x = self.relu(x)
128
+
129
+ # 打印卷积层后的形状,调试时使用
130
+ print(f"After Conv5: {x.shape}")
131
+
132
+ # 可选的全连接层部分
133
+ if self.include_top:
134
+ x = self.flatten(x)
135
+ x = self.fc1(x)
136
+ x = self.fc2(x)
137
+ x = self.fc3(x)
138
+
139
+ return x
140
+
141
+ # 前向函数
142
+ def forward_fn(data, label):
143
+ out = net(data)
144
+ loss = criterion(out, label)
145
+ return loss
146
+
147
+ # 训练步骤
148
+ def train_step(data, label):
149
+ loss, grads = grad_fn(data, label)
150
+ optimizer(grads)
151
+ return loss
152
+
153
+ # 测试模型
154
+ if __name__ == "__main__":
155
+ net = AlexNet()
156
+ optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01)
157
+ criterion = nn.MSELoss()
158
+
159
+ grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)
160
+
161
+ # 生成数据和标签
162
+ batch_size = 1
163
+ num_classes = 10
164
+ data = np.random.normal(1, 1, (batch_size, 3, 227, 227)).astype(np.float32)
165
+ label = np.random.randint(0, num_classes, (batch_size,)).astype(np.float32) # 注意此处类型应为 float32
166
+
167
+ # 转换为 MindSpore 张量
168
+ data = Tensor(data)
169
+ label = Tensor(label)
170
+
171
+ steps = 5
172
+ for i in range(steps):
173
+ debugger.start(net) # 启动调试器
174
+ loss = train_step(data, label) # 执行训练步骤
175
+ print(f"Step {i}, Loss: {loss}")
176
+ debugger.stop() # 停止调试器
177
+ debugger.step() # 计数步数
178
+ ```
179
+
180
+ ## 3. 运行训练脚本
181
+
182
+ 在命令行中执行以下命令:
183
+
184
+ ```bash
185
+ python alexnet_model.py
186
+ ```
187
+
188
+ ## 4. 查看采集结果
189
+
190
+ 执行训练命令后,工具会将模型训练过程中的精度数据采集下来。
191
+
192
+ 日志中出现如下信息表示数据采集成功,此时即可手动停止模型训练并查看采集数据。
193
+
194
+ ```markdown
195
+ ****************************************************************************
196
+ * msprobe ends successfully. *
197
+ ****************************************************************************
198
+ ```
199
+
200
+ ## 5. 数据分析
201
+
202
+ 后续精度数据分析可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。在 `dump_path` 参数指定的路径下(本例中为 `./output`)会生成如下目录结构(本例 `step` 配置为 `"0-2"`,因此会生成 step0、step1、step2 三个目录,下面以 step0 为例):
203
+
204
+ ```bash
205
+ output/
206
+ └── step0
207
+ └── rank
208
+ ├── construct.json # level为L0时,保存Cell的层级关系信息。当前场景为空
209
+ ├── dump.json # 保存API前反向输入输出数据的统计量信息
210
+ └── stack.json # 保存API的调用栈
211
+ ```
Binary file
Binary file
Binary file
Binary file
Binary file
@@ -1 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
17
+ from msprobe.mindspore.common.utils import seed_all
@@ -1,16 +1,34 @@
1
- import json
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import os
17
+ from tqdm import tqdm
3
18
 
4
- from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
- from msprobe.core.common.utils import add_time_as_suffix
6
19
  from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
- from msprobe.mindspore.common.log import logger
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
21
+ from msprobe.core.common.utils import add_time_as_suffix
8
22
  from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
23
  from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
24
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
25
+ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
11
26
  from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
27
  trim_output_compute_element_list)
28
+ from msprobe.mindspore.common.log import logger
13
29
 
30
+ cur_path = os.path.dirname(os.path.realpath(__file__))
31
+ yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
14
32
 
15
33
  class BasicInfoAndStatus:
16
34
  def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
@@ -21,6 +39,7 @@ class BasicInfoAndStatus:
21
39
  self.status = status
22
40
  self.err_msg = err_msg
23
41
 
42
+
24
43
  class ResultCsvEntry:
25
44
  def __init__(self) -> None:
26
45
  self.forward_pass_status = None
@@ -31,9 +50,9 @@ class ResultCsvEntry:
31
50
 
32
51
 
33
52
  class ApiAccuracyChecker:
34
- def __init__(self):
53
+ def __init__(self, args):
35
54
  self.api_infos = dict()
36
- self.results = dict()
55
+ self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager
37
56
 
38
57
  @staticmethod
39
58
  def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
@@ -80,25 +99,64 @@ class ApiAccuracyChecker:
80
99
  compare_result_dict[compare_algorithm_name] = compare_result
81
100
 
82
101
  if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
83
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
102
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
84
103
  status = CompareConst.PASS
85
104
  err_msg = ""
86
105
  else:
87
106
  status = CompareConst.ERROR
88
107
  err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
89
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
108
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
90
109
  basic_info_status = \
91
110
  BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
92
111
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
93
112
  return output_list
94
113
 
114
+ @staticmethod
115
+ def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
116
+ '''
117
+ Args:
118
+ api_info: ApiInfo
119
+ forward_or_backward: str
120
+ Returns:
121
+ ApiInputAggregation
122
+ '''
123
+ forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
124
+ kwargs = api_info.get_kwargs()
125
+ if forward_or_backward == Const.FORWARD:
126
+ gradient_inputs = None
127
+ else:
128
+ gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
129
+ return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
130
+
131
+ @staticmethod
132
+ def is_api_checkable(api_name_str):
133
+ '''
134
+ Args:
135
+ api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
136
+ Returns:
137
+ is_checkable: bool
138
+ Description:
139
+ tell whether this api is checkable based on the key in "data" dict in api_info.json
140
+ '''
141
+ api_name_str_list = api_name_str.split(Const.SEP)
142
+ if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
143
+ return False
144
+ api_type_str = api_name_str_list[0]
145
+ real_api_str = Const.SEP.join(api_name_str_list[1:-2])
146
+ api_list = load_yaml(yaml_path)
147
+ supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
148
+ if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
149
+ return True
150
+ if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list:
151
+ return True
152
+ return False
153
+
95
154
  def parse(self, api_info_path):
96
- with FileOpen(api_info_path, "r") as f:
97
- api_info_dict = json.load(f)
155
+ api_info_dict = load_json(api_info_path)
98
156
 
99
157
  # init global context
100
158
  task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
101
- "task field in api_info.json",accepted_type=str,
159
+ "task field in api_info.json", accepted_type=str,
102
160
  accepted_value=(MsCompareConst.STATISTICS_TASK,
103
161
  MsCompareConst.TENSOR_TASK))
104
162
  is_constructed = task == MsCompareConst.STATISTICS_TASK
@@ -112,14 +170,12 @@ class ApiAccuracyChecker:
112
170
  api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
113
171
  "data field in api_info.json", accepted_type=dict)
114
172
  for api_name, api_info in api_info_data.items():
115
- is_mint = api_name.split(Const.SEP)[0] in \
116
- (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
117
- if not is_mint:
173
+ if not self.is_api_checkable(api_name):
118
174
  continue
119
175
  forbackward_str = api_name.split(Const.SEP)[-1]
120
176
  if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
121
177
  logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
122
- api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
178
+ api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
123
179
  if api_name not in self.api_infos:
124
180
  self.api_infos[api_name] = ApiInfo(api_name)
125
181
 
@@ -128,128 +184,64 @@ class ApiAccuracyChecker:
128
184
  else:
129
185
  self.api_infos[api_name].load_backward_info(api_info)
130
186
 
187
+ def process_forward(self, api_name_str, api_info):
188
+ """处理前向检查"""
189
+ if not api_info.check_forward_info():
190
+ logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
191
+ return Const.EXCEPTION_NONE
192
+
193
+ try:
194
+ forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
195
+ except Exception as e:
196
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
197
+ f"Skipping forward check. Detailed exception information: {e}.")
198
+ return Const.EXCEPTION_NONE
199
+
200
+ forward_output_list = None
201
+ try:
202
+ forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
203
+ except Exception as e:
204
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
205
+ f"Detailed exception information: {e}.")
206
+ return forward_output_list
207
+
208
+ def process_backward(self, api_name_str, api_info):
209
+ """处理反向检查"""
210
+ if not api_info.check_backward_info():
211
+ logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
212
+ return Const.EXCEPTION_NONE
213
+
214
+ try:
215
+ backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
216
+ except Exception as e:
217
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
218
+ f"Skipping backward check. Detailed exception information: {e}.")
219
+ return Const.EXCEPTION_NONE
220
+
221
+ backward_output_list = None
222
+ try:
223
+ backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
224
+ except Exception as e:
225
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
226
+ f"Detailed exception information: {e}.")
227
+ return backward_output_list
228
+
229
+
230
+
131
231
  def run_and_compare(self):
132
- for api_name_str, api_info in self.api_infos.items():
133
- if not api_info.check_forward_info():
134
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
135
- continue
136
- forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
137
- kwargs = api_info.get_kwargs()
138
- forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
139
- forward_output_list = None
140
- try:
141
- forward_output_list = \
142
- self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
143
- except Exception as e:
144
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api"
145
- f"detailed exception information: {e}")
146
- self.record(forward_output_list)
147
-
148
- if not api_info.check_backward_info():
149
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
232
+ for api_name_str, api_info in tqdm(self.api_infos.items()):
233
+ if not self.data_manager.is_unique_api(api_name_str):
150
234
  continue
151
- gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
152
- backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
153
- backward_output_list = None
154
- try:
155
- backward_output_list = \
156
- self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
157
- except Exception as e:
158
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api"
159
- f"detailed exception information: {e}")
160
- self.record(backward_output_list)
161
-
162
- def record(self, output_list):
163
- if output_list is None:
164
- return
165
- for output in output_list:
166
- api_real_name, forward_or_backward, basic_info, compare_result_dict = output
167
- key = tuple([api_real_name, forward_or_backward])
168
- if key not in self.results:
169
- self.results[key] = []
170
- self.results[key].append(tuple([basic_info, compare_result_dict]))
171
-
172
-
173
- def to_detail_csv(self, csv_dir):
174
- # detail_csv
175
- detail_csv = []
176
- detail_csv_header_basic_info = [
177
- MsCompareConst.DETAIL_CSV_API_NAME,
178
- MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
179
- MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
180
- MsCompareConst.DETAIL_CSV_SHAPE,
181
- ]
182
- detail_csv_header_compare_result = list(compare_algorithms.keys())
183
- detail_csv_header_status = [
184
- MsCompareConst.DETAIL_CSV_PASS_STATUS,
185
- MsCompareConst.DETAIL_CSV_MESSAGE,
186
- ]
187
-
188
- detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
189
- detail_csv.append(detail_csv_header)
190
-
191
- for _, results in self.results.items():
192
- # detail csv
193
- for res in results:
194
- basic_info, compare_result_dict = res
195
- csv_row_basic_info = \
196
- [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
197
- csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
198
- for algorithm_name in detail_csv_header_compare_result)
199
- csv_row_status = [basic_info.status, basic_info.err_msg]
200
- csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
201
- detail_csv.append(csv_row)
202
-
203
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
204
- create_directory(csv_dir)
205
- write_csv(detail_csv, file_name, mode="w")
206
-
207
-
208
- def to_result_csv(self, csv_dir):
209
- result_csv_dict = dict()
210
- for key, results in self.results.items():
211
- api_real_name, forward_or_backward = key
212
- forward_or_backward_pass_status = CompareConst.PASS
213
- forward_or_backward_overall_err_msg = ""
214
- # detail csv
215
- for res in results:
216
- basic_info, _ = res
217
- if basic_info.status != CompareConst.PASS:
218
- forward_or_backward_pass_status = CompareConst.ERROR
219
- forward_or_backward_overall_err_msg += basic_info.err_msg
220
- forward_or_backward_overall_err_msg = \
221
- "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
222
-
223
- #result_csv_dict
224
- if api_real_name not in result_csv_dict:
225
- result_csv_dict[api_real_name] = ResultCsvEntry()
226
- if forward_or_backward == Const.FORWARD:
227
- result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
228
- result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
229
- else:
230
- result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
231
- result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
232
-
233
- #result_csv
234
- result_csv = []
235
- result_csv_header = [
236
- MsCompareConst.DETAIL_CSV_API_NAME,
237
- MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
238
- MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
239
- MsCompareConst.DETAIL_CSV_MESSAGE,
240
- ]
241
- result_csv.append(result_csv_header)
242
-
243
- for api_name, result_csv_entry in result_csv_dict.items():
244
- if result_csv_entry.forward_pass_status == CompareConst.PASS and \
245
- result_csv_entry.backward_pass_status == CompareConst.PASS:
246
- overall_err_msg = ""
247
- else:
248
- overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
249
- row = [api_name, result_csv_entry.forward_pass_status,
250
- result_csv_entry.backward_pass_status, overall_err_msg]
251
- result_csv.append(row)
252
-
253
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
254
- create_directory(csv_dir)
255
- write_csv(result_csv, file_name, mode="w")
235
+
236
+ # 处理前向
237
+ forward_output_list = self.process_forward(api_name_str, api_info)
238
+ if forward_output_list is not Const.EXCEPTION_NONE:
239
+ self.data_manager.record(forward_output_list)
240
+
241
+ # 处理反向
242
+ backward_output_list = self.process_backward(api_name_str, api_info)
243
+ if backward_output_list is not Const.EXCEPTION_NONE:
244
+ self.data_manager.record(backward_output_list)
245
+
246
+ self.data_manager.save_results(api_name_str)
247
+
@@ -1,11 +1,34 @@
1
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from msprobe.core.common.const import Const
3
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
17
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
18
+ from msprobe.core.common.utils import is_invalid_pattern
19
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
20
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
5
21
  from msprobe.mindspore.common.log import logger
6
22
 
23
+
7
24
  class ApiInfo:
8
25
  def __init__(self, api_name):
26
+ if not isinstance(api_name, str):
27
+ err_msg = "ApiInfo.__init__ failed: api_name is not a string"
28
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
29
+ if is_invalid_pattern(api_name):
30
+ err_msg = "ApiInfo.__init__ failed: api_name contain illegal character"
31
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
9
32
  self.api_name = api_name
10
33
  self.forward_info = None
11
34
  self.backward_info = None
@@ -59,11 +82,10 @@ class ApiInfo:
59
82
  err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
60
83
  logger.error_log_with_exp(err_msg,
61
84
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
62
- if not isinstance(compute_element_info, (list, dict)):
63
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
85
+ if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None):
86
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null"
64
87
  logger.error_log_with_exp(err_msg,
65
88
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
66
89
  kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
67
90
  for key_str, compute_element_info in kwargs_dict.items()}
68
91
  return kwargs_compute_element_dict
69
-