mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,20 +1,35 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import copy
2
- import csv
3
17
  import glob
4
18
  import os
19
+ import re
5
20
 
6
21
  import numpy as np
7
22
  import pandas as pd
8
- from msprobe.core.common.const import CompareConst, GraphMode, Const, FileCheckConst
9
- from msprobe.core.common.file_utils import FileOpen, check_path_before_create, change_mode, load_npy
23
+ from msprobe.core.common.const import CompareConst, GraphMode, Const
24
+ from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
10
25
  from msprobe.core.common.log import logger
11
26
  from msprobe.core.common.utils import add_time_with_xlsx, CompareException
12
27
  from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
13
- from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, reshape_value, compare_ops_apply
28
+ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, compare_ops_apply
14
29
  from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
15
30
 
16
31
 
17
- class row_data:
32
+ class RowData:
18
33
  def __init__(self, mode):
19
34
  self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
20
35
  self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
@@ -28,17 +43,34 @@ class row_data:
28
43
  return self.data
29
44
 
30
45
 
def get_name_dict(name: str) -> dict:
    """Split a dumped ``.npy`` file name into its named components.

    The name must consist of eight dot-separated fields followed by ``.npy``:
    op type, op name, task id, stream id, a numeric timestamp (optionally
    followed by further dot-separated digit groups), an ``input``/``output``
    marker (optionally followed by dot-separated numeric indices), slot and
    format.

    Args:
        name: File name to parse.

    Returns:
        A dict with the keys ``op_type``, ``op_name``, ``task_id``,
        ``stream_id``, ``timestamp``, ``input_output_index``, ``slot`` and
        ``format``; an empty dict when the name does not match the pattern.
    """
    compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
                                 r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
    matched = compare_pattern.match(name)
    if not matched:
        return {}

    field_names = ('op_type', 'op_name', 'task_id', 'stream_id', 'timestamp',
                   'input_output_index', 'slot', 'format')
    fields = dict(zip(field_names, matched.groups()))
    # The timestamp group may carry extra dot-separated digit groups;
    # only the leading one is kept.
    fields['timestamp'] = fields['timestamp'].split(Const.SEP)[0]
    return fields
60
+
61
+
def npy_data_read(data_path, npy_file_list, mapping_dict):
    """Build the comparison records for the ``.npy`` files of one dump directory.

    Args:
        data_path: Directory that contains the files in ``npy_file_list``.
        npy_file_list: File names to process.
        mapping_dict: Maps a file name to the name that should be parsed
            instead; files without an entry are parsed under their own name.

    Returns:
        A list of ``[file_path, compare_key, timestamp]`` entries. Files whose
        (mapped) name cannot be parsed by ``get_name_dict`` are skipped.
    """
    key_fields = ('op_name', 'task_id', 'input_output_index', 'slot')
    records = []
    for data in npy_file_list:
        # Parse the mapped name when one exists, otherwise the file name itself.
        parsed = get_name_dict(mapping_dict.get(data, data))
        if not parsed:
            continue
        key_parts = [parsed.get(field) for field in key_fields]
        compare_key = Const.SEP.join(key_parts)
        timestamp = convert_to_int(parsed.get('timestamp'))
        # The path is always built from the on-disk file name, not the mapped one.
        file_path = os.path.join(data_path, data)
        records.append([file_path, compare_key, timestamp])
    return records
@@ -48,18 +80,17 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
48
80
  data_list = []
49
81
  statistic_data_list = []
50
82
  header_index = {
51
- 'Data Type': None, 'Shape': None, 'Max Value': None,
52
- 'Min Value': None,'Avg Value': None, 'L2Norm Value': None
83
+ 'Data Type': None, 'Shape': None, 'Max Value': None,
84
+ 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
53
85
  }
54
86
  for statistic_file in statistic_file_list:
55
- with FileOpen(statistic_file, "r") as f:
56
- csv_reader = csv.reader(f, delimiter=",")
57
- header = next(csv_reader)
58
- for key in header_index.keys():
59
- for index, value in enumerate(header):
60
- if key == value:
61
- header_index[key] = index
62
- statistic_data_list.extend([row for row in csv_reader])
87
+ content = read_csv(statistic_file, as_pd=False)
88
+ header = content[0]
89
+ for key in header_index.keys():
90
+ for index, value in enumerate(header):
91
+ if key == value:
92
+ header_index[key] = index
93
+ statistic_data_list.extend(content[1:])
63
94
 
64
95
  for key in header_index.keys():
65
96
  if header_index[key] is None:
@@ -97,11 +128,9 @@ def generate_data_name(data_path):
97
128
  mapping_dict = {}
98
129
  if mapping_exist:
99
130
  for mapping_file in mapping_file_list:
100
- with FileOpen(mapping_file, "r") as f:
101
- csv_reader = csv.reader(f, delimiter=",")
102
- header = next(csv_reader)
103
- for row in csv_reader:
104
- mapping_dict[row[0]] = row[1]
131
+ content = read_csv(mapping_file, False)
132
+ for row in content[1:]:
133
+ mapping_dict[row[0]] = row[1]
105
134
 
106
135
  if npy_exist:
107
136
  data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
@@ -115,10 +144,16 @@ def generate_data_name(data_path):
115
144
  mode = GraphMode.STATISTIC_MODE
116
145
  else:
117
146
  mode = GraphMode.ERROR_MODE
118
- logger.error(f"Error mode.")
147
+ logger.error("Error mode.")
119
148
  return mode, data_list
120
149
 
121
150
 
def transform_special_string_into_float(data_frame):
    """Replace special string cells with numeric strings, in place.

    Cells equal to ``"null"`` or ``"False"`` become ``'0'`` and cells equal to
    ``"True"`` become ``'1'``. The replacements are themselves strings; the
    frame is modified in place and nothing is returned.

    Args:
        data_frame: Frame whose cells are compared against the special strings.
    """
    replacements = (("null", '0'), ("False", '0'), ("True", '1'))
    for special, numeric_text in replacements:
        data_frame[data_frame == special] = numeric_text
155
+
156
+
122
157
  class GraphMSComparator:
123
158
  def __init__(self, input_param, output_path):
124
159
  self.output_path = output_path
@@ -136,7 +171,7 @@ class GraphMSComparator:
136
171
  def compare_ops(compare_result_db, mode):
137
172
 
138
173
  def npy_mode_compute(row):
139
- result_dict = row_data(GraphMode.NPY_MODE)()
174
+ result_dict = RowData(GraphMode.NPY_MODE)()
140
175
 
141
176
  def process_npy_file(file_path, name_prefix, result):
142
177
  if os.path.exists(file_path):
@@ -158,7 +193,6 @@ class GraphMSComparator:
158
193
  result_dict[CompareConst.ERROR_MESSAGE] = error_message
159
194
 
160
195
  if not error_flag:
161
- n_value, b_value = reshape_value(n_value, b_value)
162
196
  result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
163
197
  result_dict[CompareConst.COSINE] = result_list[0]
164
198
  result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
@@ -171,7 +205,7 @@ class GraphMSComparator:
171
205
  return pd.Series(result_dict)
172
206
 
173
207
  def statistic_mode_compute(row):
174
- result_dict = row_data('STATISTIC')()
208
+ result_dict = RowData('STATISTIC')()
175
209
 
176
210
  def update_result_dict(result, rows, prefix):
177
211
  result[f'{prefix} Name'] = rows[f'{prefix} Name']
@@ -198,24 +232,30 @@ class GraphMSComparator:
198
232
  result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
199
233
  result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
200
234
  CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
201
- result_dict[CompareConst.MAX_RELATIVE_ERR] = str(result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
235
+ if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
236
+ result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
237
+ result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
202
238
  result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
203
239
  CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
204
- result_dict[CompareConst.MIN_RELATIVE_ERR] = str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
240
+ if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
241
+ result_dict[CompareConst.MIN_RELATIVE_ERR] = \
242
+ str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
205
243
  result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
206
244
  CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
207
- result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
208
- result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
245
+ if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
246
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
247
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
209
248
  result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
210
249
  CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
211
- result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
212
- result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
250
+ if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
251
+ result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
252
+ result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
213
253
  magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
214
254
  max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
215
- if magnitude_diff > CompareConst.MAGNITUDE:
216
- result_dict[CompareConst.ACCURACY] = 'No'
217
- else:
218
- result_dict[CompareConst.ACCURACY] = 'Yes'
255
+ if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
256
+ magnitude_diff = 0
257
+ result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
258
+ magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
219
259
 
220
260
  return pd.Series(result_dict)
221
261
 
@@ -238,24 +278,23 @@ class GraphMSComparator:
238
278
  is_empty = True
239
279
  if is_empty or not mode:
240
280
  continue
241
- compare_result_df = self._do_multi_process(compare_result_df, mode)
281
+ compare_result_df = self.do_multi_process(compare_result_df, mode)
242
282
  compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
243
283
  compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
244
- check_path_before_create(compare_result_path)
245
284
  self.to_excel(compare_result_df, compare_result_path)
246
285
  logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
247
-
286
+
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
    """Write the comparison result to Excel, halving it recursively when too long.

    Args:
        compare_result_df: Rows to write.
        compare_result_path: Target ``.xlsx`` path; when slicing, ``.xlsx`` is
            replaced by ``_slice_<slice_num>.xlsx``.
        slice_num: Index used for the next slice file.
        need_slice: Whether the slice suffix is added to the file name.

    Returns:
        The slice index that follows the last file written.
    """
    row_count = len(compare_result_df)
    # sheet size cannot be larger than 1048576
    if row_count >= CompareConst.MAX_EXCEL_LENGTH:
        # Too many rows for one sheet: write each half as its own slice.
        half = row_count // 2
        next_slice = self.to_excel(compare_result_df.iloc[0: half], compare_result_path, slice_num, True)
        return self.to_excel(compare_result_df.iloc[half:], compare_result_path, next_slice, True)

    target_path = compare_result_path
    if need_slice:
        target_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx')
    save_excel(target_path, compare_result_df)
    return slice_num + 1
259
298
 
260
299
  def compare_process(self, rank_id, step_id):
261
300
  # generate data_path
@@ -300,13 +339,17 @@ class GraphMSComparator:
300
339
  CompareConst.BENCH_NORM])
301
340
 
302
341
  npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
303
- npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
342
+ npu_float_data_df = npu_data_df[npu_float_type].astype(str)
343
+ transform_special_string_into_float(npu_float_data_df)
344
+ npu_data_df[npu_float_type] = npu_float_data_df.astype(float)
304
345
 
305
346
  bench_float_type = [
306
- CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
307
- CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
347
+ CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
348
+ CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
308
349
  ]
309
- bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
350
+ bench_float_data_df = bench_data_df[bench_float_type].astype(str)
351
+ transform_special_string_into_float(bench_float_data_df)
352
+ bench_data_df[bench_float_type] = bench_float_data_df.astype(float)
310
353
 
311
354
  npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
312
355
  bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
@@ -355,7 +398,7 @@ class GraphMSComparator:
355
398
  rank_step_path_dict[rank_step_key] = [dir_path]
356
399
  return dict(sorted(rank_step_path_dict.items()))
357
400
 
358
- def _do_multi_process(self, result_df, mode):
401
+ def do_multi_process(self, result_df, mode):
359
402
  try:
360
403
  result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
361
404
  except ValueError as e:
@@ -33,12 +33,13 @@ class DebuggerConfig:
33
33
  self.level_ori = common_config.level
34
34
  self.list = [] if not task_config.list else task_config.list
35
35
  self.scope = [] if not task_config.scope else task_config.scope
36
- self.data_mode = [] if not task_config.data_mode else task_config.data_mode
36
+ self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
37
37
  self.file_format = task_config.file_format
38
38
  self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
39
39
  self.check_mode = task_config.check_mode
40
40
  self.framework = Const.MS_FRAMEWORK
41
41
  self.summary_mode = task_config.summary_mode
42
+ self.async_dump = common_config.async_dump if common_config.async_dump else False
42
43
  self.check()
43
44
  create_directory(self.dump_path)
44
45
 
@@ -52,6 +53,9 @@ class DebuggerConfig:
52
53
  self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
53
54
  raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
54
55
  f"but got {self.pert_type}.")
56
+ if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
57
+ raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
58
+ f"but got {self.handler_type}.")
55
59
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
56
60
 
57
61
  def check(self):
@@ -66,4 +70,6 @@ class DebuggerConfig:
66
70
  self.file_format = "npy"
67
71
  if not self.check_mode:
68
72
  self.check_mode = "all"
73
+ if not isinstance(self.async_dump, bool):
74
+ raise Exception("The parameters async_dump should be bool.")
69
75
  return True
@@ -1,7 +1,7 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -14,25 +14,42 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ from collections import defaultdict, namedtuple
17
18
 
18
19
  import mindspore as ms
19
20
  from mindspore._c_expression import MSContext
20
21
 
21
- from msprobe.core.common.const import Const, MsgConst
22
+ from msprobe.core.common.const import Const, FileCheckConst, MsgConst
23
+ from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.common.file_utils import FileChecker
25
+ from msprobe.core.common.utils import get_real_step_or_rank
26
+ from msprobe.mindspore.cell_processor import CellProcessor
22
27
  from msprobe.mindspore.common.const import Const as MsConst
28
+ from msprobe.mindspore.common.utils import set_register_backward_hook_functions
23
29
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
30
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
31
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
24
32
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
25
33
  from msprobe.mindspore.ms_config import parse_json_config
26
34
  from msprobe.mindspore.runtime import Runtime
27
35
  from msprobe.mindspore.service import Service
28
36
  from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
29
37
 
38
+ try:
39
+ from msprobe.lib import _msprobe_c
40
+ except ImportError:
41
+ _msprobe_c = None
42
+
43
+
44
+ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"])
45
+
30
46
 
31
47
  class PrecisionDebugger:
32
48
  _instance = None
33
49
  task_not_need_service = [Const.GRAD_PROBE]
34
50
 
35
- def __new__(cls, config_path=None, opt=None):
51
+ def __new__(cls, config_path=None, task=None, dump_path=None,
52
+ level=None, step=None, opt=None):
36
53
  if not cls._instance:
37
54
  cls._instance = super().__new__(cls)
38
55
  cls._instance.initialized = False
@@ -41,22 +58,65 @@ class PrecisionDebugger:
41
58
  cls.first_start = False
42
59
  return cls._instance
43
60
 
44
- def __init__(self, config_path=None):
61
+ def __init__(self, config_path=None, task=None, dump_path=None,
62
+ level=None, step=None):
45
63
  if self.initialized:
46
64
  return
47
65
  self.initialized = True
66
+
67
+ set_register_backward_hook_functions()
68
+
48
69
  if not config_path:
49
70
  config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
71
+
72
+ config_params = ConfigParameters(config_path, task, dump_path, level)
73
+ self.check_input_params(config_params)
74
+
50
75
  common_config, task_config = parse_json_config(config_path)
76
+ common_config.task = task if task else common_config.task
51
77
  self.task = common_config.task
52
78
  if self.task == Const.GRAD_PROBE:
53
79
  self.gm = GradientMonitor(common_config, task_config)
54
80
  return
81
+ common_config.step = get_real_step_or_rank(
82
+ step, Const.STEP) if step is not None else common_config.step
83
+ common_config.level = level if level else common_config.level
84
+ common_config.dump_path = dump_path if dump_path else common_config.dump_path
55
85
  self.config = DebuggerConfig(common_config, task_config)
56
86
 
87
+ if _msprobe_c:
88
+ _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
89
+
90
+ self.config.execution_mode = self._get_execution_mode()
91
+ if self._need_service():
92
+ self.service = Service(self.config)
93
+
57
94
  Runtime.step_count = 0
58
95
  Runtime.is_running = False
59
96
 
97
+ @staticmethod
98
+ def check_input_params(args):
99
+ if args.config_path is not None:
100
+ if not isinstance(args.config_path, str):
101
+ raise MsprobeException(
102
+ MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
103
+ file_checker = FileChecker(
104
+ file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
105
+ file_checker.common_check()
106
+
107
+ if args.task is not None and args.task not in Const.TASK_LIST:
108
+ raise MsprobeException(
109
+ MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
110
+
111
+ if args.dump_path is not None:
112
+ if not isinstance(args.dump_path, str):
113
+ raise MsprobeException(
114
+ MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
115
+
116
+ if args.level is not None and args.level not in Const.LEVEL_LIST:
117
+ raise MsprobeException(
118
+ MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
119
+
60
120
  @staticmethod
61
121
  def _get_execution_mode():
62
122
  jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL)
@@ -75,11 +135,23 @@ class PrecisionDebugger:
75
135
  else:
76
136
  return MsConst.PYNATIVE_MODE
77
137
 
138
+ @staticmethod
139
+ def _is_graph_dump(config):
140
+ if config.level != MsConst.KERNEL:
141
+ return False
142
+ if not config.list or len(config.list) > 1:
143
+ return True
144
+ if '-' in config.list[0] or '/' in config.list[0]:
145
+ return True
146
+ return False
147
+
78
148
  @classmethod
79
149
  def start(cls, model=None):
80
150
  instance = cls._instance
81
151
  if not instance:
82
152
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
153
+ if _msprobe_c:
154
+ _msprobe_c._PrecisionDebugger().start()
83
155
  if instance.task in PrecisionDebugger.task_not_need_service:
84
156
  return
85
157
 
@@ -90,6 +162,7 @@ class PrecisionDebugger:
90
162
  instance.service.start(model)
91
163
  else:
92
164
  if not instance.first_start:
165
+ api_register.api_set_ori_func()
93
166
  handler = TaskHandlerFactory.create(instance.config)
94
167
  handler.handle()
95
168
 
@@ -99,18 +172,15 @@ class PrecisionDebugger:
99
172
  @classmethod
100
173
  def forward_backward_dump_end(cls):
101
174
  instance = cls._instance
102
- if not instance:
103
- raise Exception(MsgConst.NOT_CREATED_INSTANCE)
104
- if instance.task in PrecisionDebugger.task_not_need_service:
105
- return
106
- if instance.service:
107
- instance.service.forward_backward_dump_end()
175
+ instance.stop()
108
176
 
109
177
  @classmethod
110
178
  def stop(cls):
111
179
  instance = cls._instance
112
180
  if not instance:
113
181
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
182
+ if _msprobe_c:
183
+ _msprobe_c._PrecisionDebugger().stop()
114
184
  if instance.task == Const.GRAD_PROBE:
115
185
  instance.gm.stop()
116
186
  if instance.task in PrecisionDebugger.task_not_need_service:
@@ -124,10 +194,15 @@ class PrecisionDebugger:
124
194
  instance = cls._instance
125
195
  if not instance:
126
196
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
197
+ if _msprobe_c:
198
+ _msprobe_c._PrecisionDebugger().step()
127
199
  if instance.task in PrecisionDebugger.task_not_need_service:
128
200
  return
129
201
  if instance.service:
130
202
  instance.service.step()
203
+ HOOKCell.cell_count = defaultdict(int)
204
+ CellProcessor.reset_cell_stats()
205
+
131
206
  Runtime.step_count += 1
132
207
 
133
208
  @classmethod
@@ -147,4 +222,4 @@ class PrecisionDebugger:
147
222
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
148
223
  return False
149
224
  else:
150
- return instance.config.task != Const.FREE_BENCHMARK and instance.config.level != MsConst.KERNEL
225
+ return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -40,6 +40,8 @@ class DumpToolFactory:
40
40
 
41
41
  @staticmethod
42
42
  def create(config: DebuggerConfig):
43
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
44
+ raise Exception("data_mode must be one of all, input, output.")
43
45
  tool = DumpToolFactory.tools.get(config.level)
44
46
  if not tool:
45
47
  raise Exception("Valid level is needed.")