mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/compare/pt_compare.py

@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,92 +13,21 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import os.path
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
+ from msprobe.pytorch.compare.utils import read_pt_data

- import torch

- from msprobe.core.common.const import FileCheckConst
- from msprobe.core.common.exceptions import FileCheckException
- from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
- from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
-     set_dump_path
- from msprobe.core.compare.acc_compare import Comparator, ModeConfig
- from msprobe.core.compare.utils import set_stack_json_path
- from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.common.utils import load_pt
-
-
- class PTComparator(Comparator):
-     def __init__(self, mode_config, data_mapping=None):
-         super().__init__(mode_config)
-
-         self.stack_mode = mode_config.stack_mode
-         self.auto_analyze = mode_config.auto_analyze
-         self.fuzzy_match = mode_config.fuzzy_match
-         self.dump_mode = mode_config.dump_mode
-
-         self.frame_name = PTComparator.__name__
-         self.data_mapping = data_mapping
-         if isinstance(self.data_mapping, str) or self.data_mapping is None:
-             self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
-         elif isinstance(self.data_mapping, dict):
-             self.data_mapping_dict = self.data_mapping
-         else:
-             raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
-                             f"{type(self.data_mapping)}")
-
-     @staticmethod
-     def load_mapping_file(mapping_file):
-         if isinstance(mapping_file, str):
-             mapping_dict = load_yaml(mapping_file)
-         else:
-             mapping_dict = {}
-         return mapping_dict
-
-     def read_npy_data(self, dir_path, file_name):
-         if not file_name:
-             return None
-         data_path = os.path.join(dir_path, file_name)
-         path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                                    FileCheckConst.PT_SUFFIX, False)
-         data_path = path_checker.common_check()
-         try:
-             # detach because numpy can not process gradient information
-             data_value = load_pt(data_path, to_cpu=True).detach()
-         except RuntimeError as e:
-             # catch exceptions raised by load_pt here
-             logger.error(f"Failed to load the .pt file at {data_path}.")
-             raise CompareException(CompareException.INVALID_FILE_ERROR) from e
-         except AttributeError as e:
-             # catch exceptions raised by the detach method here
-             logger.error(f"Failed to detach the loaded tensor.")
-             raise CompareException(CompareException.DETACH_ERROR) from e
-         if data_value.dtype == torch.bfloat16:
-             data_value = data_value.to(torch.float32)
-         data_value = data_value.numpy()
-         return data_value
+ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple:
+     n_value = read_pt_data(npu_dir, npu_data_name)
+     b_value = read_pt_data(bench_dir, bench_data_name)
+     return n_value, b_value


  def compare(input_param, output_path, **kwargs):
-     try:
-         auto_analyze = kwargs.get('auto_analyze', True)
-         fuzzy_match = kwargs.get('fuzzy_match', False)
-         data_mapping = kwargs.get('data_mapping', None)
-         suffix = kwargs.get('suffix', '')
-
-         set_dump_path(input_param)
-         dump_mode = get_dump_mode(input_param)
-         if "stack_json_path" in input_param:
-             stack_mode = kwargs.get('stack_mode', False)
-         else:
-             stack_mode = set_stack_json_path(input_param)  # set stack_mode and set "stack_json_path" in input_param
-         check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
-         create_directory(output_path)
-         check_compare_param(input_param, output_path, dump_mode, stack_mode)
-     except (CompareException, FileCheckException) as error:
-         logger.error('Compare failed. Please check the arguments and do it again!')
-         raise CompareException(error.code) from error
+     config = setup_comparison(input_param, output_path, **kwargs)

-     mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
-     pt_comparator = PTComparator(mode_config, data_mapping)
-     pt_comparator.compare_core(input_param, output_path, suffix=suffix)
+     mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
+                              config.dump_mode, config.compared_file_type)
+     mapping_config = MappingConfig(data_mapping=config.data_mapping)
+     pt_comparator = Comparator(read_real_data, mode_config, mapping_config)
+     pt_comparator.compare_core(input_param, output_path, suffix=config.suffix)
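
Note on the refactor above: the per-framework `PTComparator` is gone; the validation that used to live in the try/except block is now handled by `setup_comparison`, and `Comparator` takes a data-reading callable (`read_real_data`) plus `ModeConfig`/`MappingConfig`. A hedged usage sketch of the slimmed-down entry point follows; the JSON key names and paths are illustrative assumptions, not taken from this hunk.

```python
# Illustrative only: the key names and paths are assumptions; adjust to your own dump layout.
from msprobe.pytorch.compare.pt_compare import compare

input_param = {
    "npu_json_path": "./npu_dump/dump.json",      # assumed key name
    "bench_json_path": "./bench_dump/dump.json",  # assumed key name
    "is_print_compare_log": True,                 # key seen in the old validation code
}
compare(input_param, output_path="./compare_output", fuzzy_match=False)
```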
msprobe/pytorch/compare/utils.py (new file)

@@ -0,0 +1,47 @@
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ import torch
+
+ from msprobe.core.common.utils import logger, CompareException
+ from msprobe.core.common.file_utils import FileChecker, FileCheckConst
+ from msprobe.pytorch.common.utils import load_pt
+
+
+ def read_pt_data(dir_path, file_name):
+     if not file_name:
+         return None
+
+     data_path = os.path.join(dir_path, file_name)
+     path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
+                                FileCheckConst.PT_SUFFIX, False)
+     data_path = path_checker.common_check()
+     try:
+         # detach because numpy can not process gradient information
+         data_value = load_pt(data_path, to_cpu=True).detach()
+     except RuntimeError as e:
+         # catch exceptions raised by load_pt here
+         logger.error(f"Failed to load the .pt file at {data_path}.")
+         raise CompareException(CompareException.INVALID_FILE_ERROR) from e
+     except AttributeError as e:
+         # catch exceptions raised by the detach method here
+         logger.error(f"Failed to detach the loaded tensor.")
+         raise CompareException(CompareException.DETACH_ERROR) from e
+     if data_value.dtype == torch.bfloat16:
+         data_value = data_value.to(torch.float32)
+     data_value = data_value.numpy()
+     return data_value
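
A minimal usage sketch of the new `read_pt_data` helper; the directory and file name below are hypothetical. The bfloat16 branch exists because NumPy has no bfloat16 dtype, so such tensors are cast to float32 before `numpy()` is called.

```python
# Hypothetical call; the directory and file name are illustrative only.
from msprobe.pytorch.compare.utils import read_pt_data

arr = read_pt_data("./dump_data/rank0", "Functional.linear.0.forward.input.0.pt")
if arr is not None:
    # bfloat16 dumps come back as float32 numpy arrays
    print(arr.shape, arr.dtype)
```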
msprobe/pytorch/debugger/debugger_config.py

@@ -13,11 +13,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import torch
-
  from msprobe.core.common.const import Const
  from msprobe.core.common.exceptions import MsprobeException
  from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import is_torch_nn_module


  class DebuggerConfig:
@@ -60,6 +59,7 @@ class DebuggerConfig:
              if isinstance(task_config.online_run_ut_recompute, bool) else False

          self.check()
+         self._check_statistics_config(task_config)

          if self.level == Const.LEVEL_L2:
              self.is_backward_kernel_dump = False
@@ -78,10 +78,13 @@ class DebuggerConfig:
          if not isinstance(self.async_dump, bool):
              raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                     f"The parameters async_dump should be bool.")
-         if self.async_dump and self.task == Const.TENSOR and not self.list:
-             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
-                                    f"The parameters async_dump is true in tensor task, the parameters list cannot be "
-                                    f"empty.")
+         if self.async_dump and self.task == Const.TENSOR:
+             if self.level == Const.LEVEL_DEBUG:
+                 self.list = []  # async_dump + debug level case ignore list
+             if not self.list and self.level != Const.LEVEL_DEBUG:
+                 raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                        f"The parameters async_dump is true in tensor task, the parameters list cannot be "
+                                        f"empty.")
          if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
              logger.warning_on_rank_0(
                  f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
@@ -93,25 +96,24 @@ class DebuggerConfig:
          self.check_kwargs()
          return True

-     def check_model(self, instance, start_model):
-         if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
-             if instance.model is not None or start_model is not None:
-                 logger.info_on_rank_0(
-                     f"The current level is not L0 or mix level, so the model parameters will not be used.")
+     def check_model(self, instance, start_model, token_range=None):
+         instance.model = start_model if start_model is not None else instance.model
+         if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
              return
-         if start_model is None and instance.model is None:
+
+         if instance.model is None:
              logger.error_on_rank_0(
-                 f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
+                 f"For level {self.level} or non-empty token_range, "
+                 f"PrecisionDebugger or start interface must receive a 'model' parameter.")
              raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")

-         instance.model = start_model if start_model is not None else instance.model
-         if isinstance(instance.model, torch.nn.Module):
+         if is_torch_nn_module(instance.model):
              return

          error_model = None
          if isinstance(instance.model, (list, tuple)):
              for model in instance.model:
-                 if not isinstance(model, torch.nn.Module):
+                 if not is_torch_nn_module(model):
                      error_model = model
                      break
          else:
@@ -119,7 +121,7 @@ class DebuggerConfig:

          if error_model is not None:
              error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
-                           f"type, currently there is a {type(error_model)} type.")
+                           f"type, currently there is an unsupported {type(error_model)} type.")
              raise MsprobeException(
                  MsprobeException.INVALID_PARAM_ERROR, error_info)

@@ -130,8 +132,23 @@
          if not self.list or len(self.list) != 1:
              raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                     f"When level is set to L2, the list must be configured as a list with one api name.")
+         if self.task != Const.TENSOR:
+             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                    f"When level is set to L2, the task must be set to tensor.")
+
          api_name = self.list[0]
          if api_name.endswith(Const.BACKWARD):
              self.is_backward_kernel_dump = True
              api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
              self.list.append(api_forward_name)
+
+     def _check_statistics_config(self, task_config):
+         if self.task != Const.STATISTICS:
+             return
+         self.tensor_list = []
+         if not hasattr(task_config, "tensor_list"):
+             return
+         if self.level == Const.LEVEL_DEBUG and task_config.tensor_list:
+             logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
+             return
+         self.tensor_list = task_config.tensor_list
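
For readers skimming the hunk above, the new `async_dump` rule in `check()` can be paraphrased by the standalone sketch below. It is illustrative only: plain strings stand in for `Const.TENSOR` and `Const.LEVEL_DEBUG`, and `validate_async_dump` is not a function in the package.

```python
def validate_async_dump(task, level, api_list):
    """Paraphrase of the new check; only relevant when async_dump is enabled."""
    if task != "tensor":
        return api_list
    if level == "debug":
        # debug-level async dump ignores the list entirely
        return []
    if not api_list:
        raise ValueError("async_dump=True with the tensor task requires a non-empty list")
    return api_list


# e.g. validate_async_dump("tensor", "debug", ["Tensor.add.0.forward"]) -> []
```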
msprobe/pytorch/debugger/precision_debugger.py

@@ -13,36 +13,22 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from collections import namedtuple
+ from torch.utils.data import dataloader

- import torch
- from msprobe.core.common.const import Const, FileCheckConst, MsgConst
+ from msprobe.core.common.const import Const, MsgConst
  from msprobe.core.common.exceptions import MsprobeException
- from msprobe.core.common.file_utils import FileChecker
- from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
+ from msprobe.core.common.utils import check_token_range
+ from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
  from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.common.utils import check_save_param
+ from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
  from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
- from msprobe.pytorch.pt_config import parse_json_config
- from msprobe.pytorch.service import Service
- from torch.utils.data import dataloader
-
- ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
-                                                    "dump_path", "level", "model"])
+ from msprobe.pytorch.pytorch_service import PytorchService
+ from msprobe.pytorch.pt_config import parse_task_config


- class PrecisionDebugger:
-     _instance = None
-     tasks_not_need_debugger = [Const.GRAD_PROBE]
-
-     def __new__(cls, *args, **kwargs):
-         if cls._instance is None:
-             cls._instance = super(PrecisionDebugger, cls).__new__(cls)
-             cls._instance.config = None
-             cls._instance.enable_dataloader = False
-         return cls._instance
+ class PrecisionDebugger(BasePrecisionDebugger):

      def __init__(
          self,
@@ -53,90 +39,65 @@ class PrecisionDebugger:
          model=None,
          step=None
      ):
-         if not hasattr(self, "initialized"):
-             config_params = ConfigParameters(config_path,
-                                              task,
-                                              dump_path,
-                                              level,
-                                              model)
-             self.check_input_params(config_params)
-
-             self.initialized = True
-             self.model = model
-             common_config, task_config = parse_json_config(config_path, task)
-             self.task = task if task else common_config.task
-             if self.task == Const.GRAD_PROBE:
-                 self.gm = GradientMonitor(common_config, task_config)
-                 return
-             if step is not None:
-                 common_config.step = get_real_step_or_rank(step, Const.STEP)
-             self.config = DebuggerConfig(
-                 common_config, task_config, task, dump_path, level
-             )
-             self.service = Service(self.config)
-             self.module_dumper = ModuleDumper(self.service)
-             self.enable_dataloader = self.config.enable_dataloader
-             if self.enable_dataloader:
-                 logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
-                 dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
-
-     @property
-     def instance(self):
-         return self._instance
+         if self.initialized:
+             return
+         super().__init__(config_path, task, dump_path, level, step)
+         self.model = model
+         if self.task == Const.GRAD_PROBE:
+             self.gm = GradientMonitor(self.common_config, self.task_config)
+             return
+         self.config = DebuggerConfig(
+             self.common_config, self.task_config, task, dump_path, level
+         )
+         self.service = PytorchService(self.config)
+         self.module_dumper = ModuleDumper(self.service)
+         self.ori_customer_func = {}
+         self.enable_dataloader = self.config.enable_dataloader
+         self._param_warning()

      @staticmethod
-     def check_input_params(args):
-         if args.config_path is not None:
-             if not isinstance(args.config_path, str):
-                 raise MsprobeException(
-                     MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
-             file_checker = FileChecker(
-                 file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
-             file_checker.common_check()
+     def _get_task_config(task, json_config):
+         return parse_task_config(task, json_config)

-         if args.task is not None and args.task not in Const.TASK_LIST:
-             raise MsprobeException(
-                 MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
-
-         if args.dump_path is not None:
-             if not isinstance(args.dump_path, str):
+     @staticmethod
+     def _iter_tracer(func):
+         def func_wrapper(*args, **kwargs):
+             debugger_instance = PrecisionDebugger._instance
+             if not debugger_instance:
                  raise MsprobeException(
-                     MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
+                     MsprobeException.INTERFACE_USAGE_ERROR,
+                     f"PrecisionDebugger must be instantiated before executing the dataloader iteration"
+                 )

-         if args.level is not None and args.level not in Const.LEVEL_LIST:
-             raise MsprobeException(
-                 MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
+             debugger_instance.enable_dataloader = False
+             if not debugger_instance.service.first_start:
+                 debugger_instance.stop()
+                 debugger_instance.step()
+             result = func(*args, **kwargs)
+             debugger_instance.start()
+             debugger_instance.enable_dataloader = True
+             return result

-         if args.model is not None:
-             logger.warning_on_rank_0(
-                 "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
-                 "It is recommended to pass the 'model' parameter in the start interface instead."
-             )
+         return func_wrapper

      @classmethod
-     def start(cls, model=None):
-         instance = cls._instance
-         if not instance:
-             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if instance.task in PrecisionDebugger.tasks_not_need_debugger:
+     def start(cls, model=None, token_range=None):
+         instance = cls._get_instance()
+         if instance is None:
              return
-         instance.config.check_model(instance, model)
+
+         check_token_range(token_range)
+         instance.config.check_model(instance, model, token_range)
+
          if instance.enable_dataloader:
              logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
          else:
-             instance.service.start(instance.model)
-
-     @classmethod
-     def forward_backward_dump_end(cls):
-         instance = cls._instance
-         instance.stop()
+             instance.service.start(instance.model, token_range)

      @classmethod
      def stop(cls):
-         instance = cls._instance
-         if not instance:
-             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if instance.task in PrecisionDebugger.tasks_not_need_debugger:
+         instance = cls._get_instance()
+         if instance is None:
              return
          if instance.enable_dataloader:
              logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
@@ -145,9 +106,8 @@ class PrecisionDebugger:

      @classmethod
      def step(cls):
-         if not cls._instance:
-             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
+         instance = cls._get_instance()
+         if instance is None:
              return
          cls._instance.service.step()

@@ -172,21 +132,23 @@
              return
          instance.service.save(variable, name, save_backward)

-     @classmethod
-     def set_init_step(cls, step):
-         instance = cls._instance
-         if not instance:
-             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         check_init_step(step)
-         instance.service.init_step = step
-         instance.service.loop = 0
+     def _param_warning(self):
+         if self.model is not None:
+             logger.warning_on_rank_0(
+                 "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
+                 "It is recommended to pass the 'model' parameter in the start interface instead."
+             )
+         if self.enable_dataloader:
+             logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
+             dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)


  def module_dump(module, dump_name):
-     if not isinstance(module, torch.nn.Module):
+     if not is_torch_nn_module(module):
          raise MsprobeException(
              MsprobeException.INVALID_PARAM_ERROR,
-             f"the module argument in module_dump must be a torch.nn.Module subclass"
+             f"the module argument in module_dump must be a torch.nn.Module type, "
+             f"but currently there is an unsupported {type(module)} type."
          )
      if not isinstance(dump_name, str):
          raise MsprobeException(
@@ -210,17 +172,3 @@ def module_dump_end():
              f"PrecisionDebugger must be instantiated before using module_dump_end interface"
          )
      instance.module_dumper.stop_module_dump()
-
-
- def iter_tracer(func):
-     def func_wrapper(*args, **kwargs):
-         debugger_instance = PrecisionDebugger.instance
-         debugger_instance.enable_dataloader = False
-         if not debugger_instance.service.first_start:
-             debugger_instance.stop()
-             debugger_instance.step()
-         result = func(*args, **kwargs)
-         debugger_instance.start()
-         debugger_instance.enable_dataloader = True
-         return result
-     return func_wrapper
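
Taken together with the `check_model` changes in `debugger_config.py`, the refactored debugger adds a `token_range` argument to `start()`. A hedged usage sketch follows: the import path mirrors the package's data-dump docs, the config file path is a placeholder, and the `[start, end]` form of `token_range` is an assumption (see `check_token_range` for what is actually accepted).

```python
# Sketch only: the config path is a placeholder and the token_range format is assumed.
import torch
from msprobe.pytorch import PrecisionDebugger  # import path per the data-dump docs

model = torch.nn.Linear(8, 8)
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump")

for _ in range(3):  # stand-in for iterating a real dataloader
    debugger.start(model=model, token_range=[0, 1])  # token_range assumed to be [start, end]
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    debugger.stop()
    debugger.step()
```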
msprobe/pytorch/dump/module_dump/hook_wrapper.py (new file)

@@ -0,0 +1,93 @@
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import wraps
+
+ import torch
+ from torch.utils.hooks import BackwardHook
+
+ from msprobe.core.common.const import Const
+ from msprobe.core.common.decorator import recursion_depth_decorator
+ from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import is_float8_tensor
+
+
+ def wrap_setup_backward_hook(func):
+     def requires_clone(tensor):
+         return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
+             tensor.requires_grad and torch.is_grad_enabled()
+
+     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
+     def parse_tensor(item, tensor_list):
+         if requires_clone(item):
+             tensor_list.append(item)
+         elif isinstance(item, (list, tuple)):
+             for value in item:
+                 parse_tensor(value, tensor_list)
+         elif isinstance(item, dict):
+             for value in item.values():
+                 parse_tensor(value, tensor_list)
+
+     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
+     def rebuild_args(item, tensor_iter):
+         if requires_clone(item):
+             result = next(tensor_iter)
+             if hasattr(result, "_base") and result._base is not None:
+                 if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
+                     torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
+             return result
+         if isinstance(item, list):
+             for index, value in enumerate(item):
+                 item[index] = rebuild_args(value, tensor_iter)
+             return item
+         if isinstance(item, dict):
+             for key, value in item.items():
+                 item[key] = rebuild_args(value, tensor_iter)
+             return item
+         if isinstance(item, tuple):
+             if hasattr(item, '_fields'):
+                 return type(item)(*[rebuild_args(i, tensor_iter) for i in item])
+             return type(item)([rebuild_args(i, tensor_iter) for i in item])
+         return item
+
+     @wraps(func)
+     def wrap_setup_hook_func(*args, **kwargs):
+         if len(args) < 2:
+             return func(*args, **kwargs)
+
+         actual_args = args[1]
+
+         tensor_list = []
+
+         parse_tensor(actual_args, tensor_list)
+
+         new_args = args[0], tuple(tensor_list)
+         hooked_tensors = func(*new_args, **kwargs)
+
+         tensor_iter = iter(hooked_tensors)
+         try:
+             new_data = rebuild_args(actual_args, tensor_iter)
+         except Exception as e:
+             logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}")
+             new_data = actual_args
+
+         return new_data
+
+     return wrap_setup_hook_func
+
+
+ def wrap_setup_input_output_hook():
+     BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
+     BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
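
The heart of `wrap_setup_backward_hook` is a flatten-process-rebuild round trip: `parse_tensor` pulls the hookable tensors out of an arbitrarily nested args structure, the original `setup_*_hook` is called on that flat tuple, and `rebuild_args` threads the hooked tensors back into the original nesting. The standalone sketch below shows the same pattern on plain Python values; it deliberately omits the float8/requires_grad filtering and the creation-meta fixup, and the helper names are illustrative.

```python
def flatten(item, out):
    """Collect leaf values depth-first, mirroring parse_tensor."""
    if isinstance(item, (list, tuple)):
        for value in item:
            flatten(value, out)
    elif isinstance(item, dict):
        for value in item.values():
            flatten(value, out)
    else:
        out.append(item)


def rebuild(item, leaf_iter):
    """Reinsert processed leaves into the original structure, mirroring rebuild_args."""
    if isinstance(item, list):
        return [rebuild(v, leaf_iter) for v in item]
    if isinstance(item, tuple):
        return type(item)(rebuild(v, leaf_iter) for v in item)
    if isinstance(item, dict):
        return {k: rebuild(v, leaf_iter) for k, v in item.items()}
    return next(leaf_iter)


args = {"x": [1, 2], "y": (3, {"z": 4})}
leaves = []
flatten(args, leaves)
processed = iter(v * 10 for v in leaves)   # stand-in for the hooked tensors
print(rebuild(args, processed))            # {'x': [10, 20], 'y': (30, {'z': 40})}
```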