mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,36 +13,22 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from collections import namedtuple
16
+ from torch.utils.data import dataloader
17
17
 
18
- import torch
19
- from msprobe.core.common.const import Const, FileCheckConst, MsgConst
18
+ from msprobe.core.common.const import Const, MsgConst
20
19
  from msprobe.core.common.exceptions import MsprobeException
21
- from msprobe.core.common.file_utils import FileChecker
22
- from msprobe.core.common.utils import get_real_step_or_rank
20
+ from msprobe.core.common.utils import check_token_range
21
+ from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
23
22
  from msprobe.pytorch.common.log import logger
24
- from msprobe.pytorch.common.utils import check_save_param
23
+ from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
25
24
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
26
25
  from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
27
26
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
28
- from msprobe.pytorch.pt_config import parse_json_config
29
- from msprobe.pytorch.service import Service
30
- from torch.utils.data import dataloader
31
-
32
- ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
33
- "dump_path", "level", "model"])
34
-
27
+ from msprobe.pytorch.pytorch_service import PytorchService
28
+ from msprobe.pytorch.pt_config import parse_task_config
35
29
 
36
- class PrecisionDebugger:
37
- _instance = None
38
- tasks_not_need_debugger = [Const.GRAD_PROBE]
39
30
 
40
- def __new__(cls, *args, **kwargs):
41
- if cls._instance is None:
42
- cls._instance = super(PrecisionDebugger, cls).__new__(cls)
43
- cls._instance.config = None
44
- cls._instance.enable_dataloader = False
45
- return cls._instance
31
+ class PrecisionDebugger(BasePrecisionDebugger):
46
32
 
47
33
  def __init__(
48
34
  self,
@@ -53,90 +39,48 @@ class PrecisionDebugger:
53
39
  model=None,
54
40
  step=None
55
41
  ):
56
- if not hasattr(self, "initialized"):
57
- config_params = ConfigParameters(config_path,
58
- task,
59
- dump_path,
60
- level,
61
- model)
62
- self.check_input_params(config_params)
63
-
64
- self.initialized = True
65
- self.model = model
66
- common_config, task_config = parse_json_config(config_path, task)
67
- self.task = task if task else common_config.task
68
- if self.task == Const.GRAD_PROBE:
69
- self.gm = GradientMonitor(common_config, task_config)
70
- return
71
- if step is not None:
72
- common_config.step = get_real_step_or_rank(step, Const.STEP)
73
- self.config = DebuggerConfig(
74
- common_config, task_config, task, dump_path, level
75
- )
76
- self.service = Service(self.config)
77
- self.module_dumper = ModuleDumper(self.service)
78
- self.enable_dataloader = self.config.enable_dataloader
79
- if self.enable_dataloader:
80
- logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
81
- dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
42
+ if self.initialized:
43
+ return
44
+ super().__init__(config_path, task, dump_path, level, step)
45
+ self.model = model
46
+ if self.task == Const.GRAD_PROBE:
47
+ self.gm = GradientMonitor(self.common_config, self.task_config)
48
+ return
49
+ self.config = DebuggerConfig(
50
+ self.common_config, self.task_config, task, dump_path, level
51
+ )
52
+ self.service = PytorchService(self.config)
53
+ self.module_dumper = ModuleDumper(self.service)
54
+ self.ori_customer_func = {}
55
+ self.enable_dataloader = self.config.enable_dataloader
56
+ self.param_warning()
82
57
 
83
58
  @property
84
59
  def instance(self):
85
60
  return self._instance
86
61
 
87
62
  @staticmethod
88
- def check_input_params(args):
89
- if args.config_path is not None:
90
- if not isinstance(args.config_path, str):
91
- raise MsprobeException(
92
- MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
93
- file_checker = FileChecker(
94
- file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
95
- file_checker.common_check()
96
-
97
- if args.task is not None and args.task not in Const.TASK_LIST:
98
- raise MsprobeException(
99
- MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
100
-
101
- if args.dump_path is not None:
102
- if not isinstance(args.dump_path, str):
103
- raise MsprobeException(
104
- MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
105
-
106
- if args.level is not None and args.level not in Const.LEVEL_LIST:
107
- raise MsprobeException(
108
- MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
109
-
110
- if args.model is not None:
111
- logger.warning_on_rank_0(
112
- "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
113
- "It is recommended to pass the 'model' parameter in the start interface instead."
114
- )
63
+ def get_task_config(task, json_config):
64
+ return parse_task_config(task, json_config)
115
65
 
116
66
  @classmethod
117
- def start(cls, model=None):
118
- instance = cls._instance
119
- if not instance:
120
- raise Exception(MsgConst.NOT_CREATED_INSTANCE)
121
- if instance.task in PrecisionDebugger.tasks_not_need_debugger:
67
+ def start(cls, model=None, token_range=None):
68
+ instance = cls.get_instance()
69
+ if instance is None:
122
70
  return
123
- instance.config.check_model(instance, model)
71
+
72
+ check_token_range(token_range)
73
+ instance.config.check_model(instance, model, token_range)
74
+
124
75
  if instance.enable_dataloader:
125
76
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
126
77
  else:
127
- instance.service.start(instance.model)
128
-
129
- @classmethod
130
- def forward_backward_dump_end(cls):
131
- instance = cls._instance
132
- instance.stop()
78
+ instance.service.start(instance.model, token_range)
133
79
 
134
80
  @classmethod
135
81
  def stop(cls):
136
- instance = cls._instance
137
- if not instance:
138
- raise Exception(MsgConst.NOT_CREATED_INSTANCE)
139
- if instance.task in PrecisionDebugger.tasks_not_need_debugger:
82
+ instance = cls.get_instance()
83
+ if instance is None:
140
84
  return
141
85
  if instance.enable_dataloader:
142
86
  logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
@@ -145,9 +89,8 @@ class PrecisionDebugger:
145
89
 
146
90
  @classmethod
147
91
  def step(cls):
148
- if not cls._instance:
149
- raise Exception(MsgConst.NOT_CREATED_INSTANCE)
150
- if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
92
+ instance = cls.get_instance()
93
+ if instance is None:
151
94
  return
152
95
  cls._instance.service.step()
153
96
 
@@ -172,12 +115,23 @@ class PrecisionDebugger:
172
115
  return
173
116
  instance.service.save(variable, name, save_backward)
174
117
 
118
+ def param_warning(self):
119
+ if self.model is not None:
120
+ logger.warning_on_rank_0(
121
+ "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
122
+ "It is recommended to pass the 'model' parameter in the start interface instead."
123
+ )
124
+ if self.enable_dataloader:
125
+ logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
126
+ dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
127
+
175
128
 
176
129
  def module_dump(module, dump_name):
177
- if not isinstance(module, torch.nn.Module):
130
+ if not is_torch_nn_module(module):
178
131
  raise MsprobeException(
179
132
  MsprobeException.INVALID_PARAM_ERROR,
180
- f"the module argument in module_dump must be a torch.nn.Module subclass"
133
+ f"the module argument in module_dump must be a torch.nn.Module type, "
134
+ f"but currently there is an unsupported {type(module)} type."
181
135
  )
182
136
  if not isinstance(dump_name, str):
183
137
  raise MsprobeException(
@@ -0,0 +1,93 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import wraps
17
+
18
+ import torch
19
+ from torch.utils.hooks import BackwardHook
20
+
21
+ from msprobe.core.common.const import Const
22
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
+ from msprobe.pytorch.common.log import logger
24
+ from msprobe.pytorch.common.utils import is_float8_tensor
25
+
26
+
27
+ def wrap_setup_backward_hook(func):
28
+ def requires_clone(tensor):
29
+ return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
30
+ tensor.requires_grad and torch.is_grad_enabled()
31
+
32
+ @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
33
+ def parse_tensor(item, tensor_list):
34
+ if requires_clone(item):
35
+ tensor_list.append(item)
36
+ elif isinstance(item, (list, tuple)):
37
+ for value in item:
38
+ parse_tensor(value, tensor_list)
39
+ elif isinstance(item, dict):
40
+ for value in item.values():
41
+ parse_tensor(value, tensor_list)
42
+
43
+ @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
44
+ def rebuild_args(item, tensor_iter):
45
+ if requires_clone(item):
46
+ result = next(tensor_iter)
47
+ if hasattr(result, "_base") and result._base is not None:
48
+ if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
49
+ torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
50
+ return result
51
+ if isinstance(item, list):
52
+ for index, value in enumerate(item):
53
+ item[index] = rebuild_args(value, tensor_iter)
54
+ return item
55
+ if isinstance(item, dict):
56
+ for key, value in item.items():
57
+ item[key] = rebuild_args(value, tensor_iter)
58
+ return item
59
+ if isinstance(item, tuple):
60
+ if hasattr(item, '_fields'):
61
+ return type(item)(*[rebuild_args(i, tensor_iter) for i in item])
62
+ return type(item)([rebuild_args(i, tensor_iter) for i in item])
63
+ return item
64
+
65
+ @wraps(func)
66
+ def wrap_setup_hook_func(*args, **kwargs):
67
+ if len(args) < 2:
68
+ return func(*args, **kwargs)
69
+
70
+ actual_args = args[1]
71
+
72
+ tensor_list = []
73
+
74
+ parse_tensor(actual_args, tensor_list)
75
+
76
+ new_args = args[0], tuple(tensor_list)
77
+ hooked_tensors = func(*new_args, **kwargs)
78
+
79
+ tensor_iter = iter(hooked_tensors)
80
+ try:
81
+ new_data = rebuild_args(actual_args, tensor_iter)
82
+ except Exception as e:
83
+ logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}")
84
+ new_data = actual_args
85
+
86
+ return new_data
87
+
88
+ return wrap_setup_hook_func
89
+
90
+
91
+ def wrap_setup_input_output_hook():
92
+ BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
93
+ BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
@@ -13,74 +13,28 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import torch
17
- from msprobe.core.common.const import Const
18
- from msprobe.core.data_dump.scope import BaseScope
19
16
  from msprobe.pytorch.common.log import logger
20
- from msprobe.pytorch.hook_module.api_registry import api_register
21
-
22
- torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
17
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
18
+ from msprobe.pytorch.hook_module.api_register import get_api_register
23
19
 
24
20
 
25
21
  class ModuleDumper:
26
22
  def __init__(self, service):
27
23
  self.service = service
28
- self.hook_handle_list = []
24
+ self.api_register = get_api_register()
29
25
 
30
26
  def start_module_dump(self, module, dump_name):
31
- api_register.api_originality()
32
- self.register_hook(module, dump_name)
33
-
34
- def stop_module_dump(self):
35
- api_register.api_modularity()
36
- for hook_handle in self.hook_handle_list:
37
- if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
38
- hook_handle.remove()
39
- self.hook_handle_list.clear()
27
+ if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'):
28
+ logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
29
+ return
40
30
 
41
- def register_hook(self, module, dump_name):
42
- prefix_name = (
43
- BaseScope.Module_Type_Module + Const.SEP +
44
- dump_name + Const.SEP +
45
- module.__class__.__name__ + Const.SEP
46
- )
47
- module_processor = self.service.module_processor
48
- _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
49
- BaseScope.Module_Type_Module,
50
- prefix_name
51
- )
31
+ ModuleProcesser.enable_module_dump = True
32
+ self.api_register.restore_all_api()
33
+ if not hasattr(module, 'msprobe_module_dump'):
34
+ self.service.module_processor.register_module_hook(module, self.service.build_hook,
35
+ recursive=False, module_names=[dump_name])
36
+ setattr(module, 'msprobe_module_dump', True)
52
37
 
53
- if module_processor.has_register_backward_hook(module):
54
- logger.warning(
55
- f"The {dump_name} module has registered deprecated register_backward_hook,"
56
- f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
57
- )
58
- if torch_version_above_or_equal_2:
59
- forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
60
- else:
61
- if not module_processor.has_register_backward_hook(module):
62
- backward_hook_handle = module.register_full_backward_hook(
63
- module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
64
- )
65
- self.hook_handle_list.append(backward_hook_handle)
66
- forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
67
- self.hook_handle_list.append(forward_hook_handle)
68
- if not module_processor.has_register_backward_hook(module):
69
- backward_hook_handle = module.register_full_backward_hook(backward_hook)
70
- self.hook_handle_list.append(backward_hook_handle)
71
-
72
- forward_pre_hook_handle = module.register_forward_pre_hook(
73
- module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
74
- )
75
- forward_hook_handle = module.register_forward_hook(
76
- module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
77
- )
78
- self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
79
- if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
80
- backward_pre_hook_handle = module.register_full_backward_pre_hook(
81
- module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
82
- )
83
- backward_hook_handle = module.register_full_backward_hook(
84
- module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
85
- )
86
- self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
38
+ def stop_module_dump(self):
39
+ ModuleProcesser.enable_module_dump = False
40
+ self.api_register.register_all_api()