mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -27,10 +27,14 @@ import numpy as np
27
27
  from tqdm import tqdm
28
28
 
29
29
  # 本地应用/库特定导入
30
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
30
+ from msprobe.core.common.const import Const, CompareConst
31
31
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
32
32
  from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
33
33
  from msprobe.mindspore.common.log import logger
34
+ from msprobe.mindspore.common.const import MsCompareConst
35
+
36
+ from msprobe.core.data_dump.data_collector import build_data_collector
37
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
34
38
 
35
39
 
36
40
  class MultiApiAccuracyChecker(ApiAccuracyChecker):
@@ -50,6 +54,12 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
50
54
  # 初始化一个属性来存储当前的设备ID(用于日志中显示)
51
55
  self.current_device_id = None
52
56
 
57
+ self.save_error_data = args.save_error_data
58
+ if self.save_error_data:
59
+ config, dump_path_aggregation = self.init_save_error_data(args)
60
+ self.data_collector = build_data_collector(config)
61
+ self.data_collector.update_dump_paths(dump_path_aggregation)
62
+
53
63
  def process_on_device(self, device_id, api_infos, progress_queue):
54
64
  """
55
65
  在特定设备上处理一部分API。
@@ -19,7 +19,8 @@ import sys
19
19
  from pathlib import Path
20
20
  import mindspore
21
21
  from msprobe.mindspore.common.log import logger
22
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
22
+ from msprobe.core.common.const import Const, CompareConst
23
+ from msprobe.mindspore.common.const import MsCompareConst
23
24
  import torch as mindtorch
24
25
  from torch import Tensor as mindtorch_tensor
25
26
  import torch.nn.functional as mindtorch_func
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,21 +13,50 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
16
+ from collections import OrderedDict
17
+
18
+ from mindspore import Tensor
19
+ from mindspore.common.hook_handle import HookHandle
20
+ from mindspore.ops.operations import _inner_ops as inner
21
+
17
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
25
+ from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.log import logger
27
+ from msprobe.mindspore.common.utils import (
28
+ is_mindtorch,
29
+ get_cells_and_names_with_index,
30
+ has_kwargs_in_forward_hook,
31
+ is_graph_mode_cell_dump_allowed
32
+ )
33
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
34
+ from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
35
+ from msprobe.core.common.runtime import Runtime
36
+
37
+
38
+ def get_cell_construct(construct):
39
+ def _construct(self, *args, **kwargs):
40
+ if hasattr(self, 'msprobe_hook'):
41
+ setattr(self, 'msprobe_input_kwargs', kwargs)
42
+ return construct(self, *args, **kwargs)
43
+ return _construct
18
44
 
19
45
 
20
46
  class CellProcessor:
21
47
  cell_count = {}
22
48
  cell_stack = []
23
- api_parent_node = ""
49
+ api_parent_node = None
24
50
  module_node = {}
51
+ cell_bw_hook_kernels = {}
52
+ cell_backward_pre_hook = []
53
+ cell_backward_hook = []
25
54
 
26
55
  def __init__(self, scope):
27
56
  self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
28
57
 
29
58
  @staticmethod
30
- def set_cell_count(cell_name):
59
+ def set_and_get_calls_number(cell_name):
31
60
  if cell_name not in CellProcessor.cell_count:
32
61
  CellProcessor.cell_count[cell_name] = 0
33
62
  else:
@@ -38,42 +67,184 @@ class CellProcessor:
38
67
  def reset_cell_stats(cls):
39
68
  cls.cell_count = {}
40
69
  cls.cell_stack = []
41
- cls.api_parent_node = ""
70
+ cls.api_parent_node = None
42
71
  cls.module_node = {}
72
+ cls.cell_bw_hook_kernels = {}
73
+ cls.cell_backward_pre_hook = []
74
+ cls.cell_backward_hook = []
43
75
 
44
- def node_hook(self, name_prefix, start_or_stop, **kwargs):
45
- def begin_hook(cell, input_data):
46
- full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
47
- if CellProcessor.cell_stack:
48
- CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
49
- else:
50
- CellProcessor.module_node[full_name] = None
76
+ def register_cell_hook(self, models, build_hook, config: DebuggerConfig):
77
+ if not models:
78
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
79
+ 'The model cannot be None, when level is "L0" or "mix"')
80
+
81
+ is_registered = False
82
+ model_type = Const.MODULE if is_mindtorch() else Const.CELL
83
+ cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
84
+ construct_name = '_call_impl' if is_mindtorch() else '_run_construct'
85
+
86
+ for index, cells_and_names in cells_with_index_in_pynative_mode.items():
87
+ model = models if index == "-1" else models[int(index)]
88
+ for name, cell in cells_and_names:
89
+ if cell == model:
90
+ continue
91
+
92
+ if not has_kwargs_in_forward_hook():
93
+ if not hasattr(cell.__class__, 'msprobe_construct'):
94
+ setattr(cell.__class__, 'msprobe_construct', True)
95
+ if hasattr(cell.__class__, construct_name):
96
+ setattr(cell.__class__, construct_name,
97
+ get_cell_construct(getattr(cell.__class__, construct_name)))
98
+ setattr(cell, 'msprobe_hook', True)
99
+
100
+ cell_index = (index + Const.SEP) if index != "-1" else ""
101
+ prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}'
102
+
103
+ forward_pre_hook = self.build_cell_hook(prefix, build_hook)
104
+ cell.register_forward_pre_hook(forward_pre_hook)
105
+
106
+ if not is_registered:
107
+ logger.info("The cell hook function is successfully mounted to the model.")
108
+ is_registered = True
109
+
110
+ if is_graph_mode_cell_dump_allowed(config):
111
+ cells_and_names_in_graph_mode = []
112
+ for index, cells_and_names in cells_with_index_in_graph_mode.items():
113
+ model = models if index == "-1" else models[int(index)]
114
+ for name, cell in cells_and_names:
115
+ if cell == model:
116
+ continue
117
+ cell_index = (index + Const.SEP) if index != "-1" else ""
118
+ cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
119
+
120
+ if cells_and_names_in_graph_mode:
121
+ Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
122
+ GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
51
123
 
52
- CellProcessor.cell_stack.append(full_name)
53
- CellProcessor.api_parent_node = full_name
124
+ def build_cell_hook(self, cell_name, build_data_hook):
125
+ def forward_pre_hook(cell, args):
126
+ index = CellProcessor.set_and_get_calls_number(cell_name)
127
+ full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
128
+ full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
54
129
 
55
- if self.scope:
56
- self.scope.begin_module(full_name)
130
+ self.set_construct_info_in_pre_hook(full_forward_name)
57
131
 
58
- def end_hook(cell, input_data, output_data):
59
- if CellProcessor.cell_stack:
60
- CellProcessor.cell_stack.pop()
61
- if CellProcessor.cell_stack:
62
- CellProcessor.api_parent_node = CellProcessor.cell_stack[-1]
132
+ if not hasattr(cell, 'msprobe_forward_hook'):
133
+ if is_mindtorch():
134
+ cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True)
135
+ else:
136
+ forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict())
137
+ if has_kwargs_in_forward_hook():
138
+ forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict())
139
+ handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict)
140
+ forward_hook_with_kwargs_dict[handle.handle_id] = True
141
+ else:
142
+ handle = HookHandle(forward_hook_dict)
143
+ forward_hook_dict[handle.handle_id] = forward_hook
144
+ forward_hook_dict.move_to_end(handle.handle_id, last=False)
145
+
146
+ setattr(cell, 'msprobe_forward_hook', True)
147
+
148
+ def get_backward_hook(backward_data_hook, full_backward_name):
149
+ def backward_hook_fn(cell, grad_input, grad_output):
150
+ new_output = backward_data_hook(cell, grad_input, grad_output)
151
+ self.set_construct_info_in_hook(full_backward_name)
152
+ cell.has_pre_hook_called = False
153
+ return new_output
154
+ return backward_hook_fn
155
+
156
+ enable_hooked = sum(
157
+ [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args]
158
+ )
159
+ if enable_hooked:
160
+ backward_hook = OrderedDict()
161
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
162
+ backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name)
163
+ CellProcessor.cell_backward_hook.append(backward_hook)
164
+ bw_hook = inner.CellBackwardHook(full_backward_name, cell,
165
+ self.cell_backward_hook[-1])
166
+ bw_hook.register_backward_hook()
167
+ CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
168
+
169
+ args = bw_hook(*args)
170
+
171
+ return args
172
+
173
+ def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
174
+ index = CellProcessor.cell_count.get(cell_name, 0)
175
+ full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
176
+ full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
177
+
178
+ self.set_construct_info_in_hook(full_forward_name)
179
+
180
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
181
+ hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
182
+ if hook_result is not None:
183
+ outputs = hook_result
63
184
  else:
64
- CellProcessor.api_parent_node = None
185
+ outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output
186
+
187
+ bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name)
188
+ if bw_hook:
189
+ if not isinstance(outputs, (Tensor, tuple)):
190
+ logger.warning("For backward hooks to be called,"
191
+ " cell output should be a Tensor or a tuple of Tensors"
192
+ f" but received {type(outputs)}")
193
+ if isinstance(outputs, tuple):
194
+ new_outputs = bw_hook(*outputs)
195
+ else:
196
+ new_outputs = bw_hook(outputs)
197
+ if isinstance(outputs, tuple) and len(outputs) == 1:
198
+ new_outputs = (new_outputs,)
199
+ outputs = new_outputs
200
+
201
+ def get_backward_pre_hook(full_backward_name, backward_data_hook):
202
+ def backward_pre_hook_fn(cell, grad_output):
203
+ cell.has_pre_hook_called = True
204
+ self.set_construct_info_in_pre_hook(full_backward_name)
205
+ if backward_data_hook:
206
+ backward_data_hook(cell, (), grad_output)
207
+ self.set_construct_info_in_hook(full_backward_name)
208
+ cell.has_pre_hook_called = False
209
+ return backward_pre_hook_fn
65
210
 
66
- if self.scope:
67
- self.scope.end_module(cell.mindstudio_reserved_name)
211
+ backward_pre_hook = OrderedDict()
212
+ backward_data_hook = None if bw_hook else hook_set.backward_hook
213
+ backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook)
214
+ CellProcessor.cell_backward_pre_hook.append(backward_pre_hook)
215
+ bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell,
216
+ self.cell_backward_pre_hook[-1])
217
+ bw_pre_hook.register_backward_pre_hook()
68
218
 
69
- return begin_hook if Const.START == start_or_stop else end_hook
219
+ if isinstance(outputs, tuple):
220
+ result = bw_pre_hook(*outputs)
221
+ else:
222
+ result = bw_pre_hook(outputs)
223
+ if isinstance(outputs, tuple):
224
+ if len(outputs) == 1:
225
+ result = (result,)
226
+ if len(result) != len(outputs):
227
+ raise TypeError(
228
+ f"The backward pre hook return value size is {len(result)} "
229
+ f"not equal to output size {len(outputs)}"
230
+ )
231
+ return result
232
+
233
+ return forward_pre_hook
70
234
 
71
- def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
72
- if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
73
- cell.has_pre_hook_called = False
235
+ def set_construct_info_in_pre_hook(self, full_name):
236
+ if self.cell_stack:
237
+ CellProcessor.module_node[full_name] = self.cell_stack[-1]
74
238
  else:
75
- if is_called_by_pre_hook:
76
- cell.has_pre_hook_called = True
77
- index = self.set_cell_count(cell_name)
78
- cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
79
- return cell.mindstudio_reserved_name
239
+ CellProcessor.module_node[full_name] = None
240
+ CellProcessor.cell_stack.append(full_name)
241
+ CellProcessor.api_parent_node = full_name
242
+ if self.scope:
243
+ self.scope.begin_module(full_name)
244
+
245
+ def set_construct_info_in_hook(self, full_name):
246
+ if self.cell_stack:
247
+ CellProcessor.cell_stack.pop()
248
+ CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None
249
+ if self.scope:
250
+ self.scope.end_module(full_name)
@@ -34,19 +34,6 @@ class Parser:
34
34
  if isinstance(subgraph_node.attrs, list):
35
35
  subgraph_node.attrs.extend(attrs)
36
36
 
37
- @staticmethod
38
- def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
39
- attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
40
- match = attr_pattern.search(text, graph_node.pos)
41
- if match:
42
- attrs = match.group(1).strip().split('\n')
43
- for attr in attrs:
44
- if not attr:
45
- break
46
- key, value = attr.split(':')
47
- if isinstance(graph_node.attrs, dict):
48
- graph_node.attrs[key.strip()] = value.strip()
49
-
50
37
  @staticmethod
51
38
  def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
52
39
  code_info = []
@@ -124,8 +111,9 @@ class Parser:
124
111
  scope_match = scope_pattern.search(text, end_pos)
125
112
  scope = scope_match.group(1) if scope_match else ""
126
113
 
127
- id_pattern = re.compile(r'.*cnode_primal_attrs:'
128
- r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE)
114
+ id_pattern = re.compile(
115
+ r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"',
116
+ re.IGNORECASE)
129
117
  unique_id_match = id_pattern.search(text, end_pos, scope_match.start())
130
118
  unique_id = unique_id_match.group(1) if unique_id_match else None
131
119
 
@@ -186,7 +174,7 @@ class Parser:
186
174
  node_info.var_inputs.append(callee_name)
187
175
 
188
176
  def parse_subgraphs(self, text: str) -> None:
189
- subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{')
177
+ subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+')
190
178
  matches = list(subgraph_pattern.finditer(text))
191
179
  end_pos = 0
192
180
  for match in matches:
@@ -203,11 +191,6 @@ class Parser:
203
191
  subgraph_info.end = end_pos
204
192
  logging.info('Parsed subgraph: %s', subgraph_name)
205
193
 
206
- def count_nodes(self) -> Tuple[int, int]:
207
- total_nodes = len(self.nodes)
208
- total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
209
- return total_nodes, total_cnodes
210
-
211
194
  def create_backward_map(self):
212
195
  for node in self.nodes.values():
213
196
  if node.scope and node.scope.startswith("Gradients"):
@@ -15,6 +15,7 @@
15
15
 
16
16
  import numpy as np
17
17
  import mindspore as ms
18
+ from mindspore import dtype as mstype
18
19
 
19
20
  from msprobe.core.common.const import Const as CoreConst
20
21
 
@@ -23,14 +24,20 @@ class Const:
23
24
  CELL = "cell"
24
25
  API = "api"
25
26
  KERNEL = "kernel"
27
+ CELL_AND_API = 'cell_and_api'
26
28
  TOOL_LEVEL_DICT = {
27
29
  CoreConst.LEVEL_L0: CELL,
28
30
  CoreConst.LEVEL_L1: API,
29
- CoreConst.LEVEL_L2: KERNEL
31
+ CoreConst.LEVEL_L2: KERNEL,
32
+ CoreConst.LEVEL_MIX: CELL_AND_API
30
33
  }
31
- PYNATIVE_MODE = "pynative"
34
+
35
+ PYNATIVE_MODE = CoreConst.PYNATIVE_MODE
36
+ GRAPH_MODE = "graph"
32
37
  GRAPH_GE_MODE = "graph_ge"
33
38
  GRAPH_KBYK_MODE = "graph_kbyk"
39
+ PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE
40
+
34
41
  JIT_LEVEL = "jit_level"
35
42
  JIT_LEVEL_O0 = "O0"
36
43
  JIT_LEVEL_O1 = "O1"
@@ -61,6 +68,7 @@ class Const:
61
68
  DROPOUT_API_NAME_PREFIX = "dropout"
62
69
 
63
70
  GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
71
+ GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD]
64
72
 
65
73
  HOOK_MS_PREFIX_DICT = {
66
74
  OPS_DATA_PREFIX: OPS_PREFIX,
@@ -69,6 +77,69 @@ class Const:
69
77
  MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
70
78
  }
71
79
 
80
+ NonDifferentiableType = (
81
+ mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte,
82
+ mstype.int16, mstype.short, mstype.uint16, mstype.ushort,
83
+ mstype.int32, mstype.intc, mstype.uint32, mstype.uintc,
84
+ mstype.int64, mstype.intp, mstype.uint64, mstype.uintp
85
+ )
86
+
87
+
88
+ class MsCompareConst:
89
+ # api_info field
90
+ MINT = "Mint"
91
+ MINT_FUNCTIONAL = "MintFunctional"
92
+ TENSOR_API = "Tensor"
93
+ FUNCTIONAL_API = "Functional"
94
+ FUSION_API = "FUSION"
95
+
96
+ API_NAME_STR_LENGTH = 4
97
+ MAX_RECURSION_DEPTH = 20
98
+
99
+ # Mindtorch api_info field
100
+ MINDTORCH_TENSOR = "Tensor"
101
+ MINDTORCH = "Torch"
102
+ MINDTORCH_FUNC = "Functional"
103
+ MINDTORCH_NPU = "NPU"
104
+ MINDTORCH_DIST = "Distributed"
105
+
106
+ MT_VALID_API_TYPES = [
107
+ MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
108
+ ]
109
+ SUPPORTED_FUSION_LIST = ["flash_attention_score"]
110
+
111
+ TASK_FIELD = "task"
112
+ STATISTICS_TASK = "statistics"
113
+ FRAMEWORK = "framework"
114
+ TENSOR_TASK = "tensor"
115
+ DUMP_DATA_DIR_FIELD = "dump_data_dir"
116
+ DATA_FIELD = "data"
117
+
118
+ # supported api yaml
119
+ SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
120
+ SUPPORTED_TENSOR_LIST_KEY = "tensor"
121
+
122
+ # detail_csv
123
+ DETAIL_CSV_API_NAME = "API Name"
124
+ DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
125
+ DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
126
+ DETAIL_CSV_SHAPE = "Shape"
127
+ DETAIL_CSV_PASS_STATUS = "Status"
128
+ DETAIL_CSV_MESSAGE = "Message"
129
+ DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
130
+
131
+ # result_csv
132
+ RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
133
+ RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
134
+ RESULT_CSV_FILE_NAME = "accuracy_checking_result"
135
+
136
+ EPSILON = 1e-8
137
+
138
+ class ProcessStatus:
139
+ SUCCESS = "success"
140
+ API_NOT_FOUND = "api_not_found"
141
+ EXCEPTION_SKIP = "exception_skip"
142
+
72
143
 
73
144
  class FreeBenchmarkConst:
74
145
  ADD_NOISE = "add_noise"