mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,17 +12,18 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import argparse
18
17
  import os
19
18
  from collections import namedtuple
20
19
 
20
+ from msprobe.core.common.file_utils import create_directory
21
+ from msprobe.pytorch.parse_tool.lib.compare import Compare
21
22
  from msprobe.pytorch.parse_tool.lib.config import Const
23
+ from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
22
24
  from msprobe.pytorch.parse_tool.lib.utils import Util
23
- from msprobe.pytorch.parse_tool.lib.compare import Compare
24
25
  from msprobe.pytorch.parse_tool.lib.visualization import Visualization
25
- from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
26
- from msprobe.core.common.file_utils import create_directory
26
+
27
27
 
28
28
  class ParseTool:
29
29
  def __init__(self):
@@ -117,7 +117,8 @@ class ParseTool:
117
117
  self.util.check_path_valid(args.golden_dump_path)
118
118
  self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
119
119
  self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
120
- compare_data_args = namedtuple('compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
120
+ compare_data_args = namedtuple('compare_data_args',
121
+ ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
121
122
  compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)
122
123
  res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count)
123
124
  self.compare.compare_data(res)
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,24 +12,24 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
16
+ import hashlib
17
17
  import os
18
18
  import re
19
- import sys
20
19
  import subprocess
21
- import hashlib
20
+ import sys
22
21
  import time
23
- import numpy as np
24
22
  from collections import namedtuple
25
- from msprobe.pytorch.parse_tool.lib.config import Const
26
- from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
27
- from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
28
- from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
29
- check_path_executable, check_path_owner_consistent
23
+
24
+ import numpy as np
30
25
  from msprobe.core.common.const import FileCheckConst
26
+ from msprobe.core.common.file_utils import change_mode, check_other_user_writable, \
27
+ check_path_executable, check_path_owner_consistent
31
28
  from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
32
29
  from msprobe.pytorch.common.log import logger
33
-
30
+ from msprobe.pytorch.parse_tool.lib.config import Const
31
+ from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
32
+ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
34
33
 
35
34
  try:
36
35
  from rich.traceback import install
@@ -135,7 +134,7 @@ class Util:
135
134
  zero_mask = (data == 0)
136
135
  data[zero_mask] += np.finfo(float).eps
137
136
  return data
138
-
137
+
139
138
  @staticmethod
140
139
  def dir_contains_only(path, endfix):
141
140
  files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
@@ -143,11 +142,11 @@ class Util:
143
142
  if not file['file'].endswith(endfix):
144
143
  return False
145
144
  return True
146
-
145
+
147
146
  @staticmethod
148
147
  def localtime_str():
149
148
  return time.strftime("%Y%m%d%H%M%S", time.localtime())
150
-
149
+
151
150
  @staticmethod
152
151
  def change_filemode_safe(path):
153
152
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -208,7 +207,7 @@ class Util:
208
207
 
209
208
  def list_numpy_files(self, path, extern_pattern=''):
210
209
  return self.list_file_with_pattern(path, Const.NUMPY_PATTERN, extern_pattern,
211
- self._gen_numpy_file_info)
210
+ self._gen_numpy_file_info)
212
211
 
213
212
  def create_columns(self, content):
214
213
  if not Columns:
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,14 +12,14 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import json
18
- import numpy as np
19
17
 
18
+ import numpy as np
19
+ from msprobe.core.common.file_utils import FileOpen, load_npy, save_npy_to_txt
20
20
  from msprobe.pytorch.parse_tool.lib.config import Const
21
- from msprobe.pytorch.parse_tool.lib.utils import Util
22
21
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
23
- from msprobe.core.common.file_utils import FileOpen, load_npy, save_npy_to_txt
22
+ from msprobe.pytorch.parse_tool.lib.utils import Util
24
23
 
25
24
 
26
25
  class Visualization:
@@ -77,7 +76,7 @@ class Visualization:
77
76
  self.util.log.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2]))
78
77
  self.util.log.info(" {}".format(item[3]))
79
78
  continue
80
- if len(msg) > 5 and len(msg[5]) >= 3:
79
+ if len(msg) > 5 and len(msg[5]) >= 3:
81
80
  summery_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
82
81
  .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
83
82
  if not title_printed:
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +19,7 @@ from collections import namedtuple
19
19
 
20
20
  import torch
21
21
  from msprobe.core.common.const import Const
22
- from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
22
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
23
23
  from msprobe.core.common.file_utils import create_directory
24
24
  from msprobe.core.common.utils import print_tools_ends_info
25
25
  from msprobe.core.data_dump.data_collector import build_data_collector
@@ -29,10 +29,10 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
29
29
  from msprobe.pytorch.common.log import logger
30
30
  from msprobe.pytorch.common.utils import get_rank_if_initialized
31
31
  from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
- from msprobe.pytorch.hook_module import remove_dropout
32
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
33
33
  from msprobe.pytorch.hook_module.api_registry import api_register
34
34
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
35
- from msprobe.pytorch.module_processer import ModuleProcesser
35
+ from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
36
36
 
37
37
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
38
38
  if torch_version_above_or_equal_2:
@@ -48,100 +48,175 @@ class Service:
48
48
  self.data_collector = build_data_collector(config)
49
49
  self.module_processor = ModuleProcesser(self.data_collector.scope)
50
50
  self.switch = False
51
+ self.inner_switch = False
51
52
  self.current_iter = 0
52
53
  self.first_start = True
53
54
  self.current_rank = None
54
55
  self.dump_iter_dir = None
55
56
  self.should_stop_service = False
56
57
  self.attl = None
57
-
58
- @staticmethod
59
- def forward_backward_dump_end():
60
- logger.info_on_rank_0("Data needed ends here.")
61
- api_register.api_originality()
62
-
63
- @staticmethod
64
- def is_registered_backward_hook(module):
65
- if hasattr(module, '_backward_hooks') and \
66
- len(module._backward_hooks) > 0 and \
67
- module._is_full_backward_hook is False:
68
- return True
69
- return False
70
-
71
- def check_register_full_backward_hook(self, module):
72
- if self.is_registered_backward_hook(module):
73
- module._backward_hooks.clear()
74
- module._is_full_backward_hook = None
75
- logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
58
+ self.params_grad_info = {}
59
+ # 提前注册,确保注册尽可能多的API hook
60
+ self.register_api_hook()
76
61
 
77
62
  def build_hook(self, module_type, name):
78
63
  def pre_hook(api_or_module_name, module, args, kwargs):
79
- if not self.should_execute_hook():
64
+ if not self.should_execute_hook(module_type, module, True):
80
65
  return args, kwargs
81
66
 
67
+ self.inner_switch = True
82
68
  if module_type == BaseScope.Module_Type_Module:
83
- api_or_module_name = module.mindstudio_reserved_name
69
+ api_or_module_name = module.mindstudio_reserved_name[-1]
70
+ else:
71
+ module.forward_data_collected = True
72
+ HOOKModule.add_module_count(name)
84
73
  self.data_collector.update_api_or_module_name(api_or_module_name)
85
74
 
86
75
  if self.config.online_run_ut:
76
+ self.inner_switch = False
87
77
  return None, None
88
78
  if self.data_collector:
89
79
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
90
- self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
80
+ self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output)
81
+
82
+ self.inner_switch = False
91
83
  return args, kwargs
92
84
 
85
+ def grad_hook(module, ori_name, param_name):
86
+ def hook_fn(grad):
87
+ if not self.should_execute_hook(module_type, module, False):
88
+ return grad
89
+ self.inner_switch = True
90
+ self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
91
+ self.inner_switch = False
92
+ return grad
93
+
94
+ return hook_fn
95
+
96
+ def register_param_hook(ori_name, module, params_dict):
97
+ '''
98
+ 注册参数hook
99
+ '''
100
+ # data_mode为forward时,不注册参数hook
101
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
102
+ for param_name, param in params_dict.items():
103
+ if param.requires_grad:
104
+ param.register_hook(grad_hook(module, ori_name, param_name))
105
+
106
+ def init_params_grad_info(module, params_dict):
107
+ '''
108
+ 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
109
+ '''
110
+ if not params_dict:
111
+ return
112
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
113
+ grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
114
+ # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
115
+ if not self.params_grad_info.get(grad_name):
116
+ data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
117
+ # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
118
+ if data_info.get(grad_name):
119
+ # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
120
+ self.data_collector.handle_data(grad_name, data_info,
121
+ flush=self.data_collector.data_processor.is_terminated)
122
+ # 记录当前模块的参数梯度信息已占位
123
+ self.params_grad_info[grad_name] = True
124
+
93
125
  def forward_hook(api_or_module_name, module, args, kwargs, output):
94
- if not self.should_execute_hook():
126
+ if not self.should_execute_hook(module_type, module, True):
95
127
  return None
96
128
 
97
- if module_type == BaseScope.Module_Type_Module:
98
- api_or_module_name = module.mindstudio_reserved_name
99
- self.data_collector.update_api_or_module_name(api_or_module_name)
100
-
129
+ self.inner_switch = True
101
130
  if self.config.online_run_ut:
131
+ self.data_collector.update_api_or_module_name(api_or_module_name)
102
132
  if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
103
133
  return None
104
- api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
134
+ api_data = ApiData(
135
+ api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
136
+ args,
137
+ kwargs,
138
+ output,
139
+ self.current_iter,
140
+ self.current_rank
141
+ )
105
142
  self.attl_send(api_data)
143
+ self.inner_switch = False
106
144
  return None
107
145
 
108
- if self.data_collector:
109
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
110
- self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
111
- if self.data_collector.if_return_forward_new_output():
112
- return self.data_collector.get_forward_new_output()
146
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
147
+ if module_type == BaseScope.Module_Type_Module:
148
+ api_or_module_name = module.mindstudio_reserved_name[-1]
149
+ self.data_collector.update_api_or_module_name(api_or_module_name)
150
+ params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
151
+ setattr(module_input_output, Const.PARAMS, params_dict)
152
+ # 判断是否需要注册参数hook
153
+ if not hasattr(module, 'params_grad_name') and params_dict:
154
+ ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
155
+ grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
156
+ # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
157
+ setattr(module, 'params_grad_name', grad_name)
158
+ register_param_hook(ori_name, module, params_dict)
159
+ self.data_collector.forward_data_collect(
160
+ api_or_module_name,
161
+ module,
162
+ pid,
163
+ module_input_output
164
+ )
165
+ init_params_grad_info(module, params_dict)
166
+ else:
167
+ self.data_collector.update_api_or_module_name(api_or_module_name)
168
+ self.data_collector.forward_output_data_collect(
169
+ api_or_module_name,
170
+ module,
171
+ pid,
172
+ module_input_output
173
+ )
174
+
175
+ if self.data_collector.if_return_forward_new_output():
176
+ forward_new_output = self.data_collector.get_forward_new_output()
177
+ self.inner_switch = False
178
+ return forward_new_output
179
+ self.inner_switch = False
113
180
  return output
114
181
 
115
182
  def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
116
183
  return forward_hook(api_or_module_name, module, args, {}, output)
117
184
 
118
185
  def backward_hook(api_or_module_name, module, grad_input, grad_output):
119
- if not self.should_execute_hook():
186
+ if not self.should_execute_hook(module_type, module, False):
120
187
  return
121
188
 
189
+ self.inner_switch = True
122
190
  if module_type == BaseScope.Module_Type_Module:
123
- api_or_module_name = module.mindstudio_reserved_name
191
+ api_or_module_name = module.mindstudio_reserved_name[-1]
124
192
  self.data_collector.update_api_or_module_name(api_or_module_name)
125
193
 
126
194
  if self.config.online_run_ut:
195
+ self.inner_switch = False
127
196
  return
128
197
 
129
198
  if self.data_collector:
130
199
  # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
131
200
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
132
201
  self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
202
+ self.inner_switch = False
133
203
 
134
204
  pid = os.getpid()
135
- forward_name_template = name + Const.FORWARD
136
- backward_name_template = name + Const.BACKWARD
137
- pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
138
- forward_hook_fn = functools.partial(forward_hook, forward_name_template)
139
- backward_hook_fn = functools.partial(backward_hook, backward_name_template)
140
- forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
141
- forward_name_template)
205
+ full_forward_name = None
206
+ full_backward_name = None
207
+ if module_type == BaseScope.Module_Type_API:
208
+ full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
209
+ full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
210
+ pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
211
+ forward_hook_fn = functools.partial(forward_hook, full_forward_name)
212
+ backward_hook_fn = functools.partial(backward_hook, full_backward_name)
213
+ forward_hook_torch_version_below_2_fn = functools.partial(
214
+ forward_hook_torch_version_below_2,
215
+ full_forward_name
216
+ )
142
217
  return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
143
218
 
144
- def start(self, model, api_origin=False):
219
+ def start(self, model):
145
220
  if self.need_stop_service():
146
221
  return
147
222
 
@@ -155,10 +230,8 @@ class Service:
155
230
 
156
231
  if self.config.rank and self.current_rank not in self.config.rank:
157
232
  return
158
- self.register_hook_new()
233
+ self.register_module_hook()
159
234
  self.first_start = False
160
- if api_origin:
161
- api_register.api_modularity()
162
235
  if self.config.online_run_ut and torch_version_above_or_equal_2:
163
236
  run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
164
237
  self.switch = True
@@ -170,30 +243,31 @@ class Service:
170
243
  def stop(self):
171
244
  if self.should_stop_service:
172
245
  return
173
- if self.config.level == "L2":
174
- return
175
246
  if self.config.step and self.current_iter not in self.config.step:
176
247
  return
177
248
  if self.config.rank and self.current_rank not in self.config.rank:
178
249
  return
179
250
  self.switch = False
251
+ if self.config.level == Const.LEVEL_L2:
252
+ return
180
253
  if self.config.online_run_ut and torch_version_above_or_equal_2:
181
254
  run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
182
255
  return
256
+ if self.config.async_dump:
257
+ self.data_collector.fill_stack_tensor_data()
258
+ self.data_collector.data_processor.dump_async_data()
183
259
  self.data_collector.write_json()
184
260
 
185
261
  def step(self):
186
262
  if self.should_stop_service:
187
263
  return
264
+ if self.config.async_dump:
265
+ self.data_collector.fill_stack_tensor_data()
266
+ self.data_collector.data_processor.dump_async_data()
267
+ self.data_collector.write_json()
188
268
  self.current_iter += 1
189
269
  self.data_collector.update_iter(self.current_iter)
190
-
191
- ModuleProcesser.reset_module_stats()
192
- HOOKModule.reset_module_stats()
193
- self.data_collector.data_writer.reset_cache()
194
-
195
- if self.config.level == Const.LEVEL_L2:
196
- self.data_collector.data_processor.reset_status()
270
+ self.reset_status()
197
271
 
198
272
  def need_stop_service(self):
199
273
  if self.should_stop_service:
@@ -204,8 +278,6 @@ class Service:
204
278
  if self.config.online_run_ut:
205
279
  # send stop signal if online_run_ut
206
280
  self.attl_stop()
207
- if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
208
- api_register.api_originality()
209
281
  self.switch = False
210
282
  self.should_stop_service = True
211
283
  print_tools_ends_info()
@@ -214,10 +286,18 @@ class Service:
214
286
  return True
215
287
  return False
216
288
 
217
- def should_execute_hook(self):
218
- if not self.switch:
289
+ def should_execute_hook(self, hook_type, module, is_forward):
290
+ is_module_hook = hook_type == BaseScope.Module_Type_Module
291
+ if is_module_hook and not self.switch:
292
+ return False
293
+ elif not is_module_hook and is_forward and not self.switch:
219
294
  return False
220
- if self.data_collector and self.data_collector.data_processor.is_terminated:
295
+ elif not is_module_hook and not is_forward and not module.forward_data_collected:
296
+ return False
297
+
298
+ if self.inner_switch:
299
+ return False
300
+ if not self.data_collector or self.data_collector.data_processor.is_terminated:
221
301
  return False
222
302
  return True
223
303
 
@@ -244,50 +324,26 @@ class Service:
244
324
  construct_file_path = os.path.join(dump_dir, "construct.json")
245
325
  free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
246
326
  self.data_collector.update_dump_paths(
247
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)
248
-
249
- def register_hook_new(self):
250
- logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
251
- if self.config.level in ["L0", "mix"]:
252
- if self.model is None:
253
- logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
254
- logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
255
- for name, module in self.model.named_modules():
256
- if module == self.model:
257
- continue
258
- prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
259
- module.__class__.__name__ + Const.SEP
260
-
261
- pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
262
- BaseScope.Module_Type_Module, prefix)
263
- if torch_version_above_or_equal_2:
264
- module.register_forward_hook(forward_hook, with_kwargs=True)
265
- else:
266
- self.check_register_full_backward_hook(module)
267
- module.register_full_backward_hook(
268
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
269
- module.register_forward_hook(forward_hook_torch_version_below_2)
270
- self.check_register_full_backward_hook(module)
271
- module.register_full_backward_hook(backward_hook)
272
-
273
- module.register_forward_pre_hook(
274
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
275
- module.register_forward_hook(
276
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
277
- if torch_version_above_or_equal_2:
278
- module.register_full_backward_pre_hook(
279
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
280
- self.check_register_full_backward_hook(module)
281
- module.register_full_backward_hook(
282
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
283
-
284
- if self.config.level in ["mix", "L1", "L2"]:
285
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
286
- self.config.online_run_ut)
327
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
328
+ )
329
+ self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
330
+
331
+ def register_api_hook(self):
332
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
333
+ logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
334
+ api_register.initialize_hook(
335
+ functools.partial(self.build_hook, BaseScope.Module_Type_API),
336
+ self.config.online_run_ut
337
+ )
287
338
  api_register.api_modularity()
288
339
 
289
- if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
290
- remove_dropout()
340
+ if self.config.level == Const.LEVEL_MIX:
341
+ register_optimizer_hook(self.data_collector)
342
+
343
+ def register_module_hook(self):
344
+ if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
345
+ logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
346
+ self.module_processor.register_module_hook(self.model, self.build_hook)
291
347
 
292
348
  def attl_init(self):
293
349
  if self.config.online_run_ut:
@@ -319,3 +375,17 @@ class Service:
319
375
  elif self.attl.socket_manager is not None:
320
376
  logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
321
377
  self.attl.socket_manager.send_stop_signal()
378
+
379
+ def reset_status(self):
380
+ ModuleProcesser.reset_module_stats()
381
+ HOOKModule.reset_module_stats()
382
+ self.data_collector.data_writer.reset_cache()
383
+ self.params_grad_info.clear()
384
+
385
+ if self.config.level == Const.LEVEL_L2:
386
+ self.data_collector.data_processor.reset_status()
387
+ return
388
+ if self.config.step and self.current_iter not in self.config.step:
389
+ return
390
+ if self.config.rank and self.current_rank not in self.config.rank:
391
+ return