mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,470 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import functools
17
- import os
18
- from collections import namedtuple, defaultdict
19
-
20
- import torch
21
- from msprobe.core.common.const import Const
22
- from msprobe.core.common.exceptions import DistributedNotInitializedError
23
- from msprobe.core.common.file_utils import create_directory
24
- from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
25
- from msprobe.core.data_dump.data_collector import build_data_collector
26
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
- from msprobe.core.data_dump.scope import BaseScope
28
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
29
- from msprobe.pytorch.common.log import logger
30
- from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
31
- from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
- from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
33
- from msprobe.pytorch.hook_module.api_registry import api_register
34
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
35
- from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
36
-
37
# Compare the numeric major version. A plain string comparison ('10.0' >= '2.0')
# is lexicographic and gives the wrong answer for multi-digit major versions.
_torch_major_version = torch.__version__.split('+')[0].split('.')[0]
torch_version_above_or_equal_2 = _torch_major_version.isdigit() and int(_torch_major_version) >= 2
if torch_version_above_or_equal_2:
    # run_ut_dispatch is only needed (and only imported) on torch >= 2.
    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch

# Bundle of the callables returned by Service.build_hook for one API or module.
HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
42
-
43
-
44
class Service:
    """Drive PyTorch data dumping for msprobe.

    The service owns a data collector built from ``config``. It installs API hooks at
    construction time (levels mix/L1/L2) and module hooks on the first effective
    ``start`` (levels L0/mix). The hooks built by ``build_hook`` forward inputs, outputs
    and gradients to the data collector. When ``config.online_run_ut`` is set, forward
    API data is sent through ATTL instead of being dumped locally. ``start``/``stop``/``step``
    toggle the dump switch and manage per-step state, and ``save`` serves the debug level.
    """

    def __init__(self, config):
        self.model = None
        self.config = config
        self.data_collector = build_data_collector(config)
        self.module_processor = ModuleProcesser(self.data_collector.scope)
        # Dump switch toggled by start()/stop().
        self.switch = False
        # Set while a hook body is running. should_execute_hook() refuses to run any hook
        # while it is True (presumably so work done by the collector is not hooked again).
        self.inner_switch = False
        self.current_iter = 0
        self.first_start = True
        self.current_rank = None
        self.dump_iter_dir = None
        self.should_stop_service = False
        self.attl = None
        # grad_name -> True once the parameter-gradient placeholder has been written.
        self.params_grad_info = {}
        # "<module name>.<param name>" -> handle of the registered gradient hook.
        self.hook_handle_dict = {}
        # Per-name call counters for save(). Always defined so that save() cannot raise
        # AttributeError; re-created by init_for_debug_level() when it sets up debug dumping.
        self.debug_variable_counter = defaultdict(int)
        # Register early so that as many API hooks as possible are installed.
        self.register_api_hook()
        self.init_for_debug_level()

    def build_hook(self, module_type, name):
        """Build the hook callbacks for one API or one module.

        Args:
            module_type: ``BaseScope.Module_Type_API`` or ``BaseScope.Module_Type_Module``.
            name: API name prefix, used to derive the forward/backward dump names of an API.
                For modules, the dump name is read from ``module.mindstudio_reserved_name``
                when the hook runs.

        Returns:
            HookFn holding the ``pre_hook``, ``forward_hook``, ``backward_hook`` and
            ``forward_hook_torch_version_below_2`` callables, with the dump name bound as
            their first argument.

        Every hook sets ``self.inner_switch`` while it runs and clears it in a ``finally``
        clause. An early return or an exception therefore cannot leave the flag set, which
        would make should_execute_hook() reject every later hook.
        """
        def pre_hook(api_or_module_name, module, args, kwargs):
            if not self.should_execute_hook(module_type, module, True):
                return args, kwargs
            is_recompute = is_recomputation()

            self.inner_switch = True
            try:
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                else:
                    # Mark the API so that its backward hook is allowed to collect data.
                    module.forward_data_collected = True
                    HOOKModule.add_module_count(name)
                self.data_collector.update_api_or_module_name(api_or_module_name)

                if self.config.online_run_ut:
                    return None, None
                if self.data_collector:
                    module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                    self.data_collector.forward_input_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
                return args, kwargs
            finally:
                self.inner_switch = False

        def grad_hook(module, ori_name, param_name):
            def hook_fn(grad):
                if not self.should_execute_hook(module_type, module, False):
                    return grad
                self.inner_switch = True
                try:
                    self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
                finally:
                    self.inner_switch = False
                return grad

            return hook_fn

        def register_param_hook(ori_name, module, params_dict):
            '''
            Register gradient hooks on the module's parameters that require grad.

            Any handle previously registered for the same parameter is removed first,
            so repeated forward passes do not stack hooks.
            '''
            # No parameter hooks when data_mode collects forward data only.
            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                for param_name, param in params_dict.items():
                    if param.requires_grad:
                        name = ori_name + Const.SEP + param_name
                        old_handle = self.hook_handle_dict.get(name)
                        if old_handle and hasattr(old_handle, "remove"):
                            old_handle.remove()
                        handle = param.register_hook(grad_hook(module, ori_name, param_name))
                        self.hook_handle_dict[name] = handle

        def init_params_grad_info(module, params_dict):
            '''
            Initialise parameter-gradient info.

            After the forward hook finishes, write a placeholder entry for the parameter
            gradients into cache_data.
            '''
            if not params_dict:
                return
            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
                # Write the placeholder into cache_data only if it has not been written yet.
                if not self.params_grad_info.get(grad_name):
                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
                    # Only parameters with requires_grad=True get gradients, so only they need a placeholder.
                    if data_info.get(grad_name):
                        # Write grad_name's data_info into cache_data now; it is updated once gradients exist.
                        self.data_collector.handle_data(grad_name, data_info,
                                                        flush=self.data_collector.data_processor.is_terminated)
                        # Record that this module's parameter-gradient placeholder has been written.
                        self.params_grad_info[grad_name] = True

        def forward_hook(api_or_module_name, module, args, kwargs, output):
            if not self.should_execute_hook(module_type, module, True):
                return None
            is_recompute = is_recomputation()

            self.inner_switch = True
            try:
                if self.config.online_run_ut:
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                        # Out of scope: nothing to send. The finally clause still clears
                        # inner_switch, so later hooks keep working.
                        return None
                    api_data = ApiData(
                        api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
                        args,
                        kwargs,
                        output,
                        self.current_iter,
                        self.current_rank
                    )
                    self.attl_send(api_data)
                    return None

                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    params_dict = {}
                    if self.config.task != Const.STRUCTURE:
                        # Only the module's own (non-recursive) parameters, keyed by their short name.
                        params_dict = {
                            key.split(Const.SEP)[-1]: value
                            for key, value in module.named_parameters(recurse=False)
                        }
                        setattr(module_input_output, Const.PARAMS, params_dict)
                    # Parameter hooks are needed only when the module has parameters.
                    if params_dict:
                        ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
                        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                        # Record the gradient entry name on the module and (re-)register the
                        # parameter hooks; register_param_hook removes stale handles itself.
                        setattr(module, 'params_grad_name', grad_name)
                        register_param_hook(ori_name, module, params_dict)
                    self.data_collector.forward_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
                    init_params_grad_info(module, params_dict)
                else:
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    self.data_collector.forward_output_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )

                if self.data_collector.if_return_forward_new_output():
                    return self.data_collector.get_forward_new_output()
                return output
            finally:
                self.inner_switch = False

        def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
            # Older torch forward hooks receive no kwargs.
            return forward_hook(api_or_module_name, module, args, {}, output)

        def backward_hook(api_or_module_name, module, grad_input, grad_output):
            if not self.should_execute_hook(module_type, module, False):
                return
            is_recompute = is_recomputation()

            self.inner_switch = True
            try:
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                self.data_collector.update_api_or_module_name(api_or_module_name)

                if self.config.online_run_ut:
                    return

                if self.data_collector:
                    # grad_input received here is actually the output of the backward pass and
                    # grad_output is its input, so the two are swapped when passed on.
                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                    self.data_collector.backward_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
            finally:
                self.inner_switch = False

        # Captured by the closures above; resolved when they are called.
        pid = os.getpid()
        full_forward_name = None
        full_backward_name = None
        if module_type == BaseScope.Module_Type_API:
            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
            full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
        pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
        forward_hook_fn = functools.partial(forward_hook, full_forward_name)
        backward_hook_fn = functools.partial(backward_hook, full_backward_name)
        forward_hook_torch_version_below_2_fn = functools.partial(
            forward_hook_torch_version_below_2,
            full_forward_name
        )
        return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

    def start(self, model):
        """Turn dumping on for the current step.

        On the first effective call this resolves the rank, initialises ATTL, registers
        module hooks and, for level mix, the optimizer hook. Ranks not listed in
        ``config.rank`` return before registering anything. Dump directories are created
        unless online_run_ut is enabled.
        """
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.need_stop_service():
            return

        self.model = model
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                self.current_rank = None
            self.attl_init()

            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_module_hook()
            if self.config.level == Const.LEVEL_MIX:
                register_optimizer_hook(self.data_collector)
            self.first_start = False
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
        self.switch = True
        logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
        if not self.config.online_run_ut:
            self.create_dirs()
            logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Turn dumping off and flush collected data for the current step (if it is dumped)."""
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.should_stop_service:
            return
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        if self.config.level == Const.LEVEL_L2:
            return
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
            return
        if self.config.async_dump:
            self.data_collector.fill_stack_tensor_data()
            if self.config.task == Const.TENSOR:
                self.data_collector.data_processor.dump_async_data()
        self.data_collector.write_json()

    def step(self):
        """Flush the current step's data, advance the iteration counter and reset per-step state."""
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.should_stop_service:
            return
        if self.config.async_dump:
            self.data_collector.fill_stack_tensor_data()
            if self.config.task == Const.TENSOR:
                self.data_collector.data_processor.dump_async_data()
        self.data_collector.write_json()
        self.current_iter += 1
        self.data_collector.update_iter(self.current_iter)
        self.reset_status()

    def need_stop_service(self):
        """Return True when dumping should not happen at the current step.

        Once the last configured step has passed, or the data processor reports
        termination, the service is stopped for good: ATTL is signalled (online_run_ut),
        the switch is turned off and the end-of-tool message is printed.
        """
        if self.should_stop_service:
            return True
        end_service = (self.config.step and self.current_iter > max(self.config.step)) or \
            (self.data_collector and self.data_collector.data_processor.is_terminated)
        if end_service:
            if self.config.online_run_ut:
                # send stop signal if online_run_ut
                self.attl_stop()
            self.switch = False
            self.should_stop_service = True
            print_tools_ends_info()
            return True
        if self.config.step and self.current_iter not in self.config.step:
            return True
        return False

    def should_execute_hook(self, hook_type, module, is_forward):
        """Return whether a hook of ``hook_type`` may collect data right now.

        Module hooks and API forward hooks require the dump switch. API backward hooks
        ignore the switch and run whenever their forward pass was collected. No hook runs
        while another hook is active (``inner_switch``) or once the processor has terminated.
        """
        is_module_hook = hook_type == BaseScope.Module_Type_Module
        if is_module_hook and not self.switch:
            return False
        elif not is_module_hook and is_forward and not self.switch:
            return False
        elif not is_module_hook and not is_forward and not module.forward_data_collected:
            return False

        if self.inner_switch:
            return False
        if not self.data_collector or self.data_collector.data_processor.is_terminated:
            return False
        return True

    def create_dirs(self):
        """Create the dump directories for the current step and point the collector at them.

        Level L2 only creates the step directory and a kernel config JSON. Other levels
        create ``step{N}/rank{R}`` (plus ``dump_tensor_data`` when the task needs tensors)
        and register the dump/stack/construct/free_benchmark file paths.
        """
        create_directory(self.config.dump_path)
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        cur_rank = self.current_rank if self.current_rank is not None else ''
        if self.config.level == Const.LEVEL_L2:
            create_directory(self.dump_iter_dir)
            kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
            self.config.kernel_config_path = kernel_config_path
            return

        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        create_directory(dump_dir)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)
        else:
            dump_data_dir = None

        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)

    def register_api_hook(self):
        """Install the API hooks for levels mix, L1 and L2."""
        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
            logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
            api_register.initialize_hook(
                functools.partial(self.build_hook, BaseScope.Module_Type_API),
                self.config.online_run_ut
            )
            api_register.api_modularity()

    def register_module_hook(self):
        """Install module hooks on ``self.model`` for levels L0 and mix."""
        if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
            logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
            self.module_processor.register_module_hook(self.model, self.build_hook)

    def attl_init(self):
        """Create the ATTL client used by online run_ut (no-op when online_run_ut is off)."""
        if self.config.online_run_ut:
            from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
            attl_config = ATTLConfig(is_benchmark_device=False,
                                     connect_ip=self.config.host,
                                     connect_port=self.config.port,
                                     nfs_path=self.config.nfs_path,
                                     tls_path=self.config.tls_path)
            need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank
            self.attl = ATTL('npu', attl_config, need_dump=need_dump)
            if self.config.nfs_path:
                self.attl.upload("start")

    def attl_send(self, api_data):
        """Send (or upload via NFS) one API's forward data through ATTL; distributed APIs are skipped."""
        logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}")
        # The API type is the first dot-separated field of the name. Taking only that
        # field avoids a ValueError inside a forward hook when the name does not split
        # into exactly three parts.
        api_type = api_data.name.split(Const.SEP, 1)[0]
        if api_type in [Const.DISTRIBUTED]:
            logger.info(f"api {api_data.name} is not supported, skip")
            return
        if self.config.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)

    def attl_stop(self):
        """Signal the end of online run_ut: upload "end" via NFS or send STOP over the socket."""
        if self.attl is None:
            # ATTL is created in start(); nothing to stop if it was never initialised.
            return
        if self.config.nfs_path:
            self.attl.upload("end")
        elif self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
            self.attl.socket_manager.send_stop_signal()

    def reset_status(self):
        """Reset per-step hook counters, collector state and gradient placeholders."""
        ModuleProcesser.reset_module_stats()
        HOOKModule.reset_module_stats()
        self.data_collector.reset_status()
        self.params_grad_info.clear()

        if self.config.level == Const.LEVEL_L2:
            self.data_collector.data_processor.reset_status()

    def init_for_debug_level(self):
        """Prepare the ``dump_path/rank{R}/debug.json`` output used by save() at the debug level."""
        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
            return
        try:
            self.current_rank = get_rank_if_initialized()
        except DistributedNotInitializedError:
            self.current_rank = None

        # dir: dump_path -- rank{} -- debug.json
        self.dump_iter_dir = self.config.dump_path
        cur_rank = self.current_rank if self.current_rank is not None else ''
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        create_directory(dump_dir)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)
        else:
            dump_data_dir = None

        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)

        self.debug_variable_counter = defaultdict(int)

    def save(self, variable, name, save_backward):
        """Record ``variable`` under ``name.<count>`` (debug level only).

        When ``save_backward`` is true, its gradient is also collected under
        ``name_grad.<count>``.
        """
        if self.config.level != Const.LEVEL_DEBUG:
            return
        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1

        name_with_count = f"{name}.{count}"
        grad_name_with_count = f"{name}_grad.{count}"

        # forward save
        self.data_collector.debug_data_collect_forward(variable, name_with_count)

        # backward save
        if save_backward:
            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
File without changes