mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -33,6 +33,9 @@ from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataM
33
33
  from msprobe.mindspore.common.log import logger
34
34
  from msprobe.mindspore.common.const import MsCompareConst
35
35
 
36
+ from msprobe.core.data_dump.data_collector import build_data_collector
37
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
38
+
36
39
 
37
40
  class MultiApiAccuracyChecker(ApiAccuracyChecker):
38
41
  def __init__(self, args):
@@ -51,6 +54,12 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
51
54
  # 初始化一个属性来存储当前的设备ID(用于日志中显示)
52
55
  self.current_device_id = None
53
56
 
57
+ self.save_error_data = args.save_error_data
58
+ if self.save_error_data:
59
+ config, dump_path_aggregation = self.init_save_error_data(args)
60
+ self.data_collector = build_data_collector(config)
61
+ self.data_collector.update_dump_paths(dump_path_aggregation)
62
+
54
63
  def process_on_device(self, device_id, api_infos, progress_queue):
55
64
  """
56
65
  在特定设备上处理一部分API。
@@ -108,7 +108,8 @@ def delete_torch_paths():
108
108
 
109
109
  if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
110
110
  raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
111
- f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.")
111
+ f"the PYTHONPATH environment variable depth does not "
112
+ f"exceed {MsCompareConst.MAX_RECURSION_DEPTH}.")
112
113
 
113
114
 
114
115
  if not is_mindtorch():
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,21 +13,50 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
16
+ from collections import OrderedDict
17
+
18
+ from mindspore import Tensor
19
+ from mindspore.common.hook_handle import HookHandle
20
+ from mindspore.ops.operations import _inner_ops as inner
21
+
17
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
25
+ from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.log import logger
27
+ from msprobe.mindspore.common.utils import (
28
+ is_mindtorch,
29
+ get_cells_and_names_with_index,
30
+ has_kwargs_in_forward_hook,
31
+ is_graph_mode_cell_dump_allowed
32
+ )
33
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
34
+ from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
35
+ from msprobe.core.common.runtime import Runtime
36
+
37
+
38
+ def get_cell_construct(construct):
39
+ def _construct(self, *args, **kwargs):
40
+ if hasattr(self, 'msprobe_hook'):
41
+ setattr(self, 'msprobe_input_kwargs', kwargs)
42
+ return construct(self, *args, **kwargs)
43
+ return _construct
18
44
 
19
45
 
20
46
  class CellProcessor:
21
47
  cell_count = {}
22
48
  cell_stack = []
23
- api_parent_node = ""
49
+ api_parent_node = None
24
50
  module_node = {}
51
+ cell_bw_hook_kernels = {}
52
+ cell_backward_pre_hook = []
53
+ cell_backward_hook = []
25
54
 
26
55
  def __init__(self, scope):
27
56
  self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
28
57
 
29
58
  @staticmethod
30
- def set_cell_count(cell_name):
59
+ def set_and_get_calls_number(cell_name):
31
60
  if cell_name not in CellProcessor.cell_count:
32
61
  CellProcessor.cell_count[cell_name] = 0
33
62
  else:
@@ -38,42 +67,184 @@ class CellProcessor:
38
67
  def reset_cell_stats(cls):
39
68
  cls.cell_count = {}
40
69
  cls.cell_stack = []
41
- cls.api_parent_node = ""
70
+ cls.api_parent_node = None
42
71
  cls.module_node = {}
72
+ cls.cell_bw_hook_kernels = {}
73
+ cls.cell_backward_pre_hook = []
74
+ cls.cell_backward_hook = []
43
75
 
44
- def node_hook(self, name_prefix, start_or_stop, **kwargs):
45
- def begin_hook(cell, input_data):
46
- full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
47
- if CellProcessor.cell_stack:
48
- CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
49
- else:
50
- CellProcessor.module_node[full_name] = None
76
+ def register_cell_hook(self, models, build_hook, config: DebuggerConfig):
77
+ if not models:
78
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
79
+ 'The model cannot be None, when level is "L0" or "mix"')
80
+
81
+ is_registered = False
82
+ model_type = Const.MODULE if is_mindtorch() else Const.CELL
83
+ cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
84
+ construct_name = '_call_impl' if is_mindtorch() else '_run_construct'
85
+
86
+ for index, cells_and_names in cells_with_index_in_pynative_mode.items():
87
+ model = models if index == "-1" else models[int(index)]
88
+ for name, cell in cells_and_names:
89
+ if cell == model:
90
+ continue
91
+
92
+ if not has_kwargs_in_forward_hook():
93
+ if not hasattr(cell.__class__, 'msprobe_construct'):
94
+ setattr(cell.__class__, 'msprobe_construct', True)
95
+ if hasattr(cell.__class__, construct_name):
96
+ setattr(cell.__class__, construct_name,
97
+ get_cell_construct(getattr(cell.__class__, construct_name)))
98
+ setattr(cell, 'msprobe_hook', True)
99
+
100
+ cell_index = (index + Const.SEP) if index != "-1" else ""
101
+ prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}'
102
+
103
+ forward_pre_hook = self.build_cell_hook(prefix, build_hook)
104
+ cell.register_forward_pre_hook(forward_pre_hook)
105
+
106
+ if not is_registered:
107
+ logger.info("The cell hook function is successfully mounted to the model.")
108
+ is_registered = True
109
+
110
+ if is_graph_mode_cell_dump_allowed(config):
111
+ cells_and_names_in_graph_mode = []
112
+ for index, cells_and_names in cells_with_index_in_graph_mode.items():
113
+ model = models if index == "-1" else models[int(index)]
114
+ for name, cell in cells_and_names:
115
+ if cell == model:
116
+ continue
117
+ cell_index = (index + Const.SEP) if index != "-1" else ""
118
+ cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
119
+
120
+ if cells_and_names_in_graph_mode:
121
+ Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
122
+ GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
51
123
 
52
- CellProcessor.cell_stack.append(full_name)
53
- CellProcessor.api_parent_node = full_name
124
+ def build_cell_hook(self, cell_name, build_data_hook):
125
+ def forward_pre_hook(cell, args):
126
+ index = CellProcessor.set_and_get_calls_number(cell_name)
127
+ full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
128
+ full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
54
129
 
55
- if self.scope:
56
- self.scope.begin_module(full_name)
130
+ self.set_construct_info_in_pre_hook(full_forward_name)
57
131
 
58
- def end_hook(cell, input_data, output_data):
59
- if CellProcessor.cell_stack:
60
- CellProcessor.cell_stack.pop()
61
- if CellProcessor.cell_stack:
62
- CellProcessor.api_parent_node = CellProcessor.cell_stack[-1]
132
+ if not hasattr(cell, 'msprobe_forward_hook'):
133
+ if is_mindtorch():
134
+ cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True)
135
+ else:
136
+ forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict())
137
+ if has_kwargs_in_forward_hook():
138
+ forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict())
139
+ handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict)
140
+ forward_hook_with_kwargs_dict[handle.handle_id] = True
141
+ else:
142
+ handle = HookHandle(forward_hook_dict)
143
+ forward_hook_dict[handle.handle_id] = forward_hook
144
+ forward_hook_dict.move_to_end(handle.handle_id, last=False)
145
+
146
+ setattr(cell, 'msprobe_forward_hook', True)
147
+
148
+ def get_backward_hook(backward_data_hook, full_backward_name):
149
+ def backward_hook_fn(cell, grad_input, grad_output):
150
+ new_output = backward_data_hook(cell, grad_input, grad_output)
151
+ self.set_construct_info_in_hook(full_backward_name)
152
+ cell.has_pre_hook_called = False
153
+ return new_output
154
+ return backward_hook_fn
155
+
156
+ enable_hooked = sum(
157
+ [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args]
158
+ )
159
+ if enable_hooked:
160
+ backward_hook = OrderedDict()
161
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
162
+ backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name)
163
+ CellProcessor.cell_backward_hook.append(backward_hook)
164
+ bw_hook = inner.CellBackwardHook(full_backward_name, cell,
165
+ self.cell_backward_hook[-1])
166
+ bw_hook.register_backward_hook()
167
+ CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
168
+
169
+ args = bw_hook(*args)
170
+
171
+ return args
172
+
173
+ def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
174
+ index = CellProcessor.cell_count.get(cell_name, 0)
175
+ full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
176
+ full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
177
+
178
+ self.set_construct_info_in_hook(full_forward_name)
179
+
180
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
181
+ hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
182
+ if hook_result is not None:
183
+ outputs = hook_result
63
184
  else:
64
- CellProcessor.api_parent_node = None
185
+ outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output
186
+
187
+ bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name)
188
+ if bw_hook:
189
+ if not isinstance(outputs, (Tensor, tuple)):
190
+ logger.warning("For backward hooks to be called,"
191
+ " cell output should be a Tensor or a tuple of Tensors"
192
+ f" but received {type(outputs)}")
193
+ if isinstance(outputs, tuple):
194
+ new_outputs = bw_hook(*outputs)
195
+ else:
196
+ new_outputs = bw_hook(outputs)
197
+ if isinstance(outputs, tuple) and len(outputs) == 1:
198
+ new_outputs = (new_outputs,)
199
+ outputs = new_outputs
200
+
201
+ def get_backward_pre_hook(full_backward_name, backward_data_hook):
202
+ def backward_pre_hook_fn(cell, grad_output):
203
+ cell.has_pre_hook_called = True
204
+ self.set_construct_info_in_pre_hook(full_backward_name)
205
+ if backward_data_hook:
206
+ backward_data_hook(cell, (), grad_output)
207
+ self.set_construct_info_in_hook(full_backward_name)
208
+ cell.has_pre_hook_called = False
209
+ return backward_pre_hook_fn
65
210
 
66
- if self.scope:
67
- self.scope.end_module(cell.mindstudio_reserved_name)
211
+ backward_pre_hook = OrderedDict()
212
+ backward_data_hook = None if bw_hook else hook_set.backward_hook
213
+ backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook)
214
+ CellProcessor.cell_backward_pre_hook.append(backward_pre_hook)
215
+ bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell,
216
+ self.cell_backward_pre_hook[-1])
217
+ bw_pre_hook.register_backward_pre_hook()
68
218
 
69
- return begin_hook if Const.START == start_or_stop else end_hook
219
+ if isinstance(outputs, tuple):
220
+ result = bw_pre_hook(*outputs)
221
+ else:
222
+ result = bw_pre_hook(outputs)
223
+ if isinstance(outputs, tuple):
224
+ if len(outputs) == 1:
225
+ result = (result,)
226
+ if len(result) != len(outputs):
227
+ raise TypeError(
228
+ f"The backward pre hook return value size is {len(result)} "
229
+ f"not equal to output size {len(outputs)}"
230
+ )
231
+ return result
232
+
233
+ return forward_pre_hook
70
234
 
71
- def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
72
- if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
73
- cell.has_pre_hook_called = False
235
+ def set_construct_info_in_pre_hook(self, full_name):
236
+ if self.cell_stack:
237
+ CellProcessor.module_node[full_name] = self.cell_stack[-1]
74
238
  else:
75
- if is_called_by_pre_hook:
76
- cell.has_pre_hook_called = True
77
- index = self.set_cell_count(cell_name)
78
- cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
79
- return cell.mindstudio_reserved_name
239
+ CellProcessor.module_node[full_name] = None
240
+ CellProcessor.cell_stack.append(full_name)
241
+ CellProcessor.api_parent_node = full_name
242
+ if self.scope:
243
+ self.scope.begin_module(full_name)
244
+
245
+ def set_construct_info_in_hook(self, full_name):
246
+ if self.cell_stack:
247
+ CellProcessor.cell_stack.pop()
248
+ CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None
249
+ if self.scope:
250
+ self.scope.end_module(full_name)
@@ -34,19 +34,6 @@ class Parser:
34
34
  if isinstance(subgraph_node.attrs, list):
35
35
  subgraph_node.attrs.extend(attrs)
36
36
 
37
- @staticmethod
38
- def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
39
- attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
40
- match = attr_pattern.search(text, graph_node.pos)
41
- if match:
42
- attrs = match.group(1).strip().split('\n')
43
- for attr in attrs:
44
- if not attr:
45
- break
46
- key, value = attr.split(':')
47
- if isinstance(graph_node.attrs, dict):
48
- graph_node.attrs[key.strip()] = value.strip()
49
-
50
37
  @staticmethod
51
38
  def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
52
39
  code_info = []
@@ -124,8 +111,9 @@ class Parser:
124
111
  scope_match = scope_pattern.search(text, end_pos)
125
112
  scope = scope_match.group(1) if scope_match else ""
126
113
 
127
- id_pattern = re.compile(r'.*cnode_primal_attrs:'
128
- r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE)
114
+ id_pattern = re.compile(
115
+ r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"',
116
+ re.IGNORECASE)
129
117
  unique_id_match = id_pattern.search(text, end_pos, scope_match.start())
130
118
  unique_id = unique_id_match.group(1) if unique_id_match else None
131
119
 
@@ -186,7 +174,7 @@ class Parser:
186
174
  node_info.var_inputs.append(callee_name)
187
175
 
188
176
  def parse_subgraphs(self, text: str) -> None:
189
- subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{')
177
+ subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+')
190
178
  matches = list(subgraph_pattern.finditer(text))
191
179
  end_pos = 0
192
180
  for match in matches:
@@ -203,11 +191,6 @@ class Parser:
203
191
  subgraph_info.end = end_pos
204
192
  logging.info('Parsed subgraph: %s', subgraph_name)
205
193
 
206
- def count_nodes(self) -> Tuple[int, int]:
207
- total_nodes = len(self.nodes)
208
- total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
209
- return total_nodes, total_cnodes
210
-
211
194
  def create_backward_map(self):
212
195
  for node in self.nodes.values():
213
196
  if node.scope and node.scope.startswith("Gradients"):
@@ -15,6 +15,7 @@
15
15
 
16
16
  import numpy as np
17
17
  import mindspore as ms
18
+ from mindspore import dtype as mstype
18
19
 
19
20
  from msprobe.core.common.const import Const as CoreConst
20
21
 
@@ -23,14 +24,20 @@ class Const:
23
24
  CELL = "cell"
24
25
  API = "api"
25
26
  KERNEL = "kernel"
27
+ CELL_AND_API = 'cell_and_api'
26
28
  TOOL_LEVEL_DICT = {
27
29
  CoreConst.LEVEL_L0: CELL,
28
30
  CoreConst.LEVEL_L1: API,
29
- CoreConst.LEVEL_L2: KERNEL
31
+ CoreConst.LEVEL_L2: KERNEL,
32
+ CoreConst.LEVEL_MIX: CELL_AND_API
30
33
  }
31
- PYNATIVE_MODE = "pynative"
34
+
35
+ PYNATIVE_MODE = CoreConst.PYNATIVE_MODE
36
+ GRAPH_MODE = "graph"
32
37
  GRAPH_GE_MODE = "graph_ge"
33
38
  GRAPH_KBYK_MODE = "graph_kbyk"
39
+ PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE
40
+
34
41
  JIT_LEVEL = "jit_level"
35
42
  JIT_LEVEL_O0 = "O0"
36
43
  JIT_LEVEL_O1 = "O1"
@@ -61,6 +68,7 @@ class Const:
61
68
  DROPOUT_API_NAME_PREFIX = "dropout"
62
69
 
63
70
  GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
71
+ GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD]
64
72
 
65
73
  HOOK_MS_PREFIX_DICT = {
66
74
  OPS_DATA_PREFIX: OPS_PREFIX,
@@ -69,6 +77,13 @@ class Const:
69
77
  MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
70
78
  }
71
79
 
80
+ NonDifferentiableType = (
81
+ mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte,
82
+ mstype.int16, mstype.short, mstype.uint16, mstype.ushort,
83
+ mstype.int32, mstype.intc, mstype.uint32, mstype.uintc,
84
+ mstype.int64, mstype.intp, mstype.uint64, mstype.uintp
85
+ )
86
+
72
87
 
73
88
  class MsCompareConst:
74
89
  # api_info field
@@ -88,14 +103,11 @@ class MsCompareConst:
88
103
  MINDTORCH_NPU = "NPU"
89
104
  MINDTORCH_DIST = "Distributed"
90
105
 
91
-
92
-
93
106
  MT_VALID_API_TYPES = [
94
107
  MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
95
108
  ]
96
109
  SUPPORTED_FUSION_LIST = ["flash_attention_score"]
97
110
 
98
-
99
111
  TASK_FIELD = "task"
100
112
  STATISTICS_TASK = "statistics"
101
113
  FRAMEWORK = "framework"
@@ -129,8 +141,6 @@ class MsCompareConst:
129
141
  EXCEPTION_SKIP = "exception_skip"
130
142
 
131
143
 
132
-
133
-
134
144
  class FreeBenchmarkConst:
135
145
  ADD_NOISE = "add_noise"
136
146
  BIT_NOISE = "bit_noise"
@@ -13,19 +13,34 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import inspect
16
17
  import os
17
18
  import random
19
+ import types
18
20
 
19
21
  import mindspore as ms
20
-
21
22
  from mindspore import ops
23
+ from mindspore.common.jit_config import JitConfig
22
24
  from mindspore.mint import nn
23
25
 
26
+ from msprobe.core.common.const import Const
27
+ from msprobe.core.common.decorator import recursion_depth_decorator
24
28
  from msprobe.core.common.exceptions import DistributedNotInitializedError
25
29
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
26
30
  from msprobe.core.common.log import logger
27
- from msprobe.core.common.const import Const
28
31
  from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
32
+ from msprobe.mindspore.common.const import Const as MsConst
33
+
34
+ try:
35
+ from mindspore._c_expression import _set_init_iter
36
+ except ImportError:
37
+ enable_dynamic_kbyk_dump = False
38
+ else:
39
+ enable_dynamic_kbyk_dump = True
40
+
41
+ mindtorch_check_result = None
42
+ register_backward_hook_functions = {}
43
+ kwargs_exist_in_forward_hook = None
29
44
 
30
45
 
31
46
  class MsprobeStep(ms.train.Callback):
@@ -33,6 +48,11 @@ class MsprobeStep(ms.train.Callback):
33
48
  super(MsprobeStep, self).__init__()
34
49
  self.debugger = debugger
35
50
 
51
+ def on_train_begin(self, run_context):
52
+ self.debugger.start()
53
+ if enable_dynamic_kbyk_dump:
54
+ _set_init_iter(0)
55
+
36
56
  def on_train_step_begin(self, run_context):
37
57
  self.debugger.start()
38
58
 
@@ -82,8 +102,8 @@ def convert_to_int(value):
82
102
 
83
103
 
84
104
  def clean_input_kwargs(cell):
85
- if hasattr(cell, 'input_kwargs'):
86
- del cell.input_kwargs
105
+ if hasattr(cell, 'msprobe_input_kwargs'):
106
+ del cell.msprobe_input_kwargs
87
107
 
88
108
 
89
109
  def list_lowest_level_directories(root_dir):
@@ -152,9 +172,6 @@ def remove_dropout():
152
172
  nn.functional.dropout = dropout_ext
153
173
 
154
174
 
155
- mindtorch_check_result = None
156
-
157
-
158
175
  def is_mindtorch():
159
176
  global mindtorch_check_result
160
177
  if mindtorch_check_result is None:
@@ -169,11 +186,11 @@ def is_mindtorch():
169
186
  return mindtorch_check_result
170
187
 
171
188
 
172
- register_backward_hook_functions = {}
173
-
174
-
175
189
  def set_register_backward_hook_functions():
176
190
  global register_backward_hook_functions
191
+ if register_backward_hook_functions:
192
+ return
193
+
177
194
  if is_mindtorch():
178
195
  import torch
179
196
  from msprobe.mindspore.mindtorch import (_call_impl,
@@ -192,7 +209,7 @@ def set_register_backward_hook_functions():
192
209
 
193
210
  def check_save_param(variable, name, save_backward):
194
211
  # try catch this api to skip invalid call
195
- valid_data_types = tuple([ms.Tensor, int, float, str])
212
+ valid_data_types = (ms.Tensor, int, float, str)
196
213
  if not is_save_variable_valid(variable, valid_data_types):
197
214
  valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
198
215
  logger.warning("PrecisionDebugger.save variable type not valid, "
@@ -209,3 +226,103 @@ def check_save_param(variable, name, save_backward):
209
226
  "should be bool. "
210
227
  "Skip current save process.")
211
228
  raise ValueError
229
+
230
+
231
+ def is_graph_mode_cell_dump_allowed(config):
232
+ if config.task not in [Const.TENSOR, Const.STATISTICS] or is_mindtorch() or not hasattr(ops, 'DumpGradient'):
233
+ return False
234
+ valid_mix_level = [MsConst.CELL_AND_API, Const.LEVEL_MIX]
235
+ if config.level in valid_mix_level and config.execution_mode == MsConst.PYNATIVE_MODE:
236
+ return True
237
+ return config.level == MsConst.CELL or config.level == Const.LEVEL_L0
238
+
239
+
240
+ @recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit')
241
+ def is_decorated_by_jit(func):
242
+ closure = getattr(func, '__closure__', [])
243
+ if closure:
244
+ for obj in closure:
245
+ if isinstance(obj.cell_contents, JitConfig):
246
+ return True
247
+ elif isinstance(obj.cell_contents, types.FunctionType) and hasattr(obj.cell_contents, '__closure__'):
248
+ if is_decorated_by_jit(obj.cell_contents):
249
+ return True
250
+ return False
251
+
252
+
253
+ @recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
254
+ def get_cells_and_names(model, cells_set=None, name_prefix=''):
255
+ cells_set = cells_set if cells_set else set()
256
+ if model in cells_set:
257
+ return
258
+
259
+ cells_set.add(model)
260
+ jit_decorated = is_decorated_by_jit(model.construct)
261
+ yield name_prefix, model, jit_decorated
262
+ if jit_decorated:
263
+ return
264
+
265
+ children_cells = getattr(model, '_cells')
266
+ for name, cell in children_cells.items():
267
+ if cell:
268
+ cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
269
+ jit_decorated = is_decorated_by_jit(model.construct)
270
+ if jit_decorated:
271
+ yield cells_name_prefix, cell, jit_decorated
272
+ else:
273
+ for ele in get_cells_and_names(cell, cells_set, cells_name_prefix):
274
+ yield ele
275
+
276
+
277
+ def get_cells_and_names_with_index(models):
278
+ cells_with_index_in_pynative_mode = {}
279
+ cells_with_index_in_graph_mode = {}
280
+
281
+ def distinguish_cells(cells):
282
+ cells_in_pynative_mode = []
283
+ cells_in_graph_mode = []
284
+ for name, cell, jit_decorated in cells:
285
+ if jit_decorated:
286
+ cells_in_graph_mode.append((name, cell))
287
+ else:
288
+ cells_in_pynative_mode.append((name, cell))
289
+ return cells_in_pynative_mode, cells_in_graph_mode
290
+
291
+ if is_mindtorch():
292
+ if isinstance(models, (list, tuple)):
293
+ for index, model in enumerate(models):
294
+ cells_with_index_in_pynative_mode[str(index)] = model.named_modules()
295
+ else:
296
+ cells_with_index_in_pynative_mode["-1"] = models.named_modules()
297
+ else:
298
+ if isinstance(models, (list, tuple)):
299
+ for index, model in enumerate(models):
300
+ cells = get_cells_and_names(model)
301
+ cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
302
+ cells_with_index_in_pynative_mode[str(index)] = cells_in_pynative_mode
303
+ cells_with_index_in_graph_mode[str(index)] = cells_in_graph_mode
304
+ else:
305
+ cells = get_cells_and_names(models)
306
+ cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
307
+ cells_with_index_in_pynative_mode["-1"] = cells_in_pynative_mode
308
+ cells_with_index_in_graph_mode["-1"] = cells_in_graph_mode
309
+
310
+ return cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode
311
+
312
+
313
+ def has_kwargs_in_forward_hook():
314
+ global kwargs_exist_in_forward_hook
315
+
316
+ if kwargs_exist_in_forward_hook is None:
317
+ if is_mindtorch():
318
+ kwargs_exist_in_forward_hook = True
319
+ return kwargs_exist_in_forward_hook
320
+
321
+ try:
322
+ func_params = inspect.signature(nn.Cell.register_forward_hook).parameters
323
+ kwargs_exist_in_forward_hook = 'with_kwargs' in func_params
324
+ except Exception:
325
+ kwargs_exist_in_forward_hook = False
326
+ return kwargs_exist_in_forward_hook
327
+
328
+ return kwargs_exist_in_forward_hook