mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -13,17 +13,18 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import concurrent
17
+ import copy
16
18
  import csv
17
19
  import os
18
- import copy
19
20
  import threading
20
21
  import traceback
21
22
  from datetime import datetime, timezone, timedelta
22
23
 
23
24
  from msprobe.core.common.const import Const, FileCheckConst
24
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
25
- from msprobe.core.common.log import logger
26
25
  from msprobe.core.common.decorator import recursion_depth_decorator
26
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, check_path_before_create
27
+ from msprobe.core.common.log import logger
27
28
 
28
29
  lock = threading.Lock()
29
30
 
@@ -39,6 +40,7 @@ class DataWriter:
39
40
  self.debug_file_path = None
40
41
  self.dump_error_info_path = None
41
42
  self.flush_size = 1000
43
+ self.md5_flush_size = 5000
42
44
  self.larger_flush_size = 20000
43
45
  self.cache_data = {}
44
46
  self.cache_stack = {}
@@ -46,6 +48,9 @@ class DataWriter:
46
48
  self.cache_debug = {}
47
49
  self.stat_stack_list = []
48
50
  self._error_log_initialized = False
51
+ self._cache_logged_error_types = set()
52
+ self.crc32_stack_list = []
53
+ self.data_updated = False
49
54
 
50
55
  @staticmethod
51
56
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -57,11 +62,31 @@ class DataWriter:
57
62
  spawn_writer = csv.writer(csv_file)
58
63
  if not is_exists:
59
64
  spawn_writer.writerow(result_header)
60
- spawn_writer.writerows([result,])
65
+ spawn_writer.writerows([result, ])
61
66
  is_new_file = not is_exists
62
67
  if is_new_file:
63
68
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
64
69
 
70
+ @recursion_depth_decorator("JsonWriter: DataWriter._replace_crc32_placeholders")
71
+ def _replace_crc32_placeholders(self, data, crc32_results):
72
+ """
73
+ 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32
74
+ """
75
+ if isinstance(data, dict):
76
+ for k, v in list(data.items()):
77
+ if k == Const.MD5_INDEX and isinstance(v, int):
78
+ idx = v
79
+ # 防越界
80
+ crc = crc32_results[idx] if idx < len(crc32_results) else None
81
+ # 删除占位符,改成真实字段
82
+ del data[k]
83
+ data[Const.MD5] = crc
84
+ else:
85
+ self._replace_crc32_placeholders(v, crc32_results)
86
+ elif isinstance(data, (list, tuple)):
87
+ for item in data:
88
+ self._replace_crc32_placeholders(item, crc32_results)
89
+
65
90
  @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
66
91
  def _replace_stat_placeholders(self, data, stat_result):
67
92
  if isinstance(data, dict):
@@ -107,6 +132,25 @@ class DataWriter:
107
132
  self.cache_stack = {}
108
133
  self.cache_construct = {}
109
134
  self.cache_debug = {}
135
+ self._cache_logged_error_types = set()
136
+
137
+ def append_crc32_to_buffer(self, future: concurrent.futures.Future) -> int:
138
+ """
139
+ 把一个计算 CRC32 的 Future 放入队列,返回占位符索引
140
+ """
141
+ idx = len(self.crc32_stack_list)
142
+ self.crc32_stack_list.append(future)
143
+ return idx
144
+
145
+ def flush_crc32_stack(self):
146
+ """
147
+ 等待所有 CRC32 计算完成,返回结果列表
148
+ """
149
+ if not self.crc32_stack_list:
150
+ return []
151
+ results = [f.result() for f in self.crc32_stack_list]
152
+ self.crc32_stack_list = []
153
+ return results
110
154
 
111
155
  def initialize_json_file(self, **kwargs):
112
156
  if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
@@ -142,18 +186,32 @@ class DataWriter:
142
186
 
143
187
  length = len(dump_data)
144
188
 
145
- threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
189
+ # 1) 先取到 config(如果没有,就拿 None)
190
+ cfg = getattr(self, "config", None)
191
+ # 2) 再取 summary_mode(如果 cfg 是 None 或者没 summary_mode,就拿 None)
192
+ summary_mode = getattr(cfg, "summary_mode", None)
193
+
194
+ if summary_mode == Const.MD5:
195
+ threshold = self.md5_flush_size
196
+ else:
197
+ threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
146
198
 
147
199
  if length % threshold == 0:
148
200
  self.write_json()
149
201
 
150
- def write_error_log(self, message: str):
202
+ def write_error_log(self, message: str, error_type: str):
151
203
  """
152
204
  写错误日志:
153
205
  - 第一次调用时以 'w' 模式清空文件,之后都用 'a' 模式追加
154
206
  - 添加时间戳
155
207
  - 在 message 后写入当前的调用栈(方便追踪日志来源)
156
208
  """
209
+ # 如果同类型错误已经记录过,跳过
210
+ if error_type in self._cache_logged_error_types:
211
+ return
212
+ # 否则添加到已记录集合,并继续写日志
213
+ self._cache_logged_error_types.add(error_type)
214
+
157
215
  try:
158
216
  mode = "w" if not self._error_log_initialized else "a"
159
217
  self._error_log_initialized = True
@@ -182,6 +240,7 @@ class DataWriter:
182
240
  logger.warning(f"The dump data({dump_data}) should be a dict.")
183
241
  return
184
242
 
243
+ self.data_updated = True
185
244
  key = next(iter(new_data.keys()))
186
245
  if key in dump_data:
187
246
  dump_data.get(key).update(new_data.get(key))
@@ -190,6 +249,7 @@ class DataWriter:
190
249
 
191
250
  def update_stack(self, name, stack_data):
192
251
  with lock:
252
+ self.data_updated = True
193
253
  api_list = self.cache_stack.get(stack_data)
194
254
  if api_list is None:
195
255
  self.cache_stack.update({stack_data: [name]})
@@ -198,10 +258,12 @@ class DataWriter:
198
258
 
199
259
  def update_construct(self, new_data):
200
260
  with lock:
261
+ self.data_updated = True
201
262
  self.cache_construct.update(new_data)
202
263
 
203
264
  def update_debug(self, new_data):
204
265
  with lock:
266
+ self.data_updated = True
205
267
  self.cache_debug['data'].update(new_data)
206
268
 
207
269
  def write_data_json(self, file_path):
@@ -268,9 +330,21 @@ class DataWriter:
268
330
  stat_result = self.flush_stat_stack()
269
331
  # 遍历 cache_data,将占位符替换为最终统计值
270
332
  if stat_result:
333
+ self.data_updated = True
271
334
  self._replace_stat_placeholders(self.cache_data, stat_result)
272
335
  if self.cache_debug:
273
336
  self._replace_stat_placeholders(self.cache_debug, stat_result)
337
+
338
+ crc32_result = self.flush_crc32_stack()
339
+ if crc32_result:
340
+ self.data_updated = True
341
+ self._replace_crc32_placeholders(self.cache_data, crc32_result)
342
+ if self.cache_debug:
343
+ self._replace_crc32_placeholders(self.cache_debug, crc32_result)
344
+
345
+ if not self.data_updated:
346
+ return
347
+
274
348
  if self.cache_data:
275
349
  self.write_data_json(self.dump_file_path)
276
350
  if self.cache_stack:
@@ -279,4 +353,4 @@ class DataWriter:
279
353
  self.write_construct_info_json(self.construct_file_path)
280
354
  if self.cache_debug:
281
355
  self.write_debug_info_json(self.debug_file_path)
282
-
356
+ self.data_updated = False
@@ -69,8 +69,7 @@ class BaseScope(ABC):
69
69
  self.scope = scope
70
70
  self.api_list = api_list
71
71
 
72
- @staticmethod
73
- def rectify_args(scope, api_list):
72
+ def rectify_args(self, scope, api_list):
74
73
  if not isinstance(api_list, list):
75
74
  raise ScopeException(ScopeException.InvalidApiStr,
76
75
  f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
@@ -104,12 +103,11 @@ class BaseScope(ABC):
104
103
 
105
104
 
106
105
  class ListScope(BaseScope):
107
- @staticmethod
108
- def rectify_args(scope, api_list):
106
+ def rectify_args(self, scope, api_list):
109
107
  if scope and api_list:
110
108
  raise ScopeException(ScopeException.ArgConflict,
111
109
  f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
112
- return super(ListScope, ListScope).rectify_args(scope, api_list)
110
+ return super().rectify_args(scope, api_list)
113
111
 
114
112
  def check(self, name):
115
113
  if not self.scope or name in self.scope:
@@ -147,7 +145,7 @@ class RangeScope(BaseScope, ABC):
147
145
  f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
148
146
 
149
147
  def rectify_args(self, scope, api_list):
150
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
148
+ scope, api_list = super().rectify_args(scope, api_list)
151
149
  if scope and len(scope) != 2:
152
150
  raise ScopeException(ScopeException.InvalidScope,
153
151
  f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
@@ -13,34 +13,42 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import gc
16
17
  import os
17
18
  import threading
18
19
  from abc import ABC, abstractmethod
19
20
  from collections import defaultdict
20
21
 
21
- from msprobe.core.common.log import logger
22
22
  from msprobe.core.common.runtime import Runtime
23
23
  from msprobe.core.common.utils import Const, ThreadSafe
24
24
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
25
25
 
26
26
 
27
27
  class HookSet:
28
- def __init__(self, forward_hook=None, forward_pre_hook=None, backward_hook=None, backward_pre_hook=None):
29
- self.forward_hook = forward_hook
28
+ def __init__(
29
+ self,
30
+ forward_pre_hook=None,
31
+ forward_hook=None,
32
+ backward_pre_hook=None,
33
+ backward_hook=None,
34
+ distributed_forward_hook=None
35
+ ):
30
36
  self.forward_pre_hook = forward_pre_hook
31
- self.backward_hook = backward_hook
37
+ self.forward_hook = forward_hook
32
38
  self.backward_pre_hook = backward_pre_hook
39
+ self.backward_hook = backward_hook
40
+ self.distributed_forward_hook = distributed_forward_hook
33
41
 
34
42
 
35
43
  class BaseHookManager(ABC):
36
44
  inner_switch = defaultdict(bool)
45
+ inner_api_count = defaultdict(int)
37
46
  hook_handle_dict = {}
38
47
  params_grad_info = {}
39
48
 
40
- def __init__(self, data_collector, config, attl_manager=None):
49
+ def __init__(self, data_collector, config):
41
50
  self.data_collector = data_collector
42
51
  self.config = config
43
- self.attl_manager = attl_manager
44
52
 
45
53
  @property
46
54
  def _pid(self):
@@ -51,6 +59,30 @@ class BaseHookManager(ABC):
51
59
  def _is_recompute(self):
52
60
  pass
53
61
 
62
+ @staticmethod
63
+ def reset_status():
64
+ BaseHookManager.inner_switch = defaultdict(bool)
65
+ BaseHookManager.inner_api_count = defaultdict(int)
66
+ BaseHookManager.hook_handle_dict.clear()
67
+ BaseHookManager.params_grad_info.clear()
68
+
69
+ @staticmethod
70
+ def ensure_gc_enabled():
71
+ is_gc_disabled = not gc.isenabled()
72
+ if is_gc_disabled:
73
+ gc.enable()
74
+ return is_gc_disabled
75
+
76
+ @staticmethod
77
+ def restore_gc_state(original_state):
78
+ if original_state:
79
+ gc.disable()
80
+
81
+ @staticmethod
82
+ def _clear_input_kwargs(module, tid):
83
+ if hasattr(module, 'msprobe_input_kwargs') and tid in module.msprobe_input_kwargs:
84
+ del module.msprobe_input_kwargs[tid]
85
+
54
86
  @staticmethod
55
87
  @abstractmethod
56
88
  def _no_grad_context():
@@ -63,18 +95,30 @@ class BaseHookManager(ABC):
63
95
 
64
96
  @staticmethod
65
97
  @abstractmethod
66
- def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
98
+ def _get_count(name):
67
99
  pass
68
100
 
69
101
  @staticmethod
70
- def _clear_input_kwargs(module):
71
- if hasattr(module, 'msprobe_input_kwargs'):
72
- del module.msprobe_input_kwargs
102
+ @abstractmethod
103
+ def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs):
104
+ pass
73
105
 
74
106
  @abstractmethod
75
107
  def build_hook(self):
76
108
  pass
77
109
 
110
+ @abstractmethod
111
+ def _register_forward_hook(self, module, api_name):
112
+ pass
113
+
114
+ @abstractmethod
115
+ def _register_backward_hook(self, module, full_backward_name, args):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _register_backward_pre_hook(self, module, full_backward_name, output):
120
+ pass
121
+
78
122
  @abstractmethod
79
123
  def _get_params_dict(self, module):
80
124
  pass
@@ -96,7 +140,7 @@ class BaseHookManager(ABC):
96
140
  old_handle = BaseHookManager.hook_handle_dict.get(name)
97
141
  if old_handle and hasattr(old_handle, "remove"):
98
142
  old_handle.remove()
99
- handle = param.register_hook(self._build_grad_hook(module, ori_name, param_name))
143
+ handle = param.register_hook(self._build_grad_hook(ori_name, param_name))
100
144
  BaseHookManager.hook_handle_dict[name] = handle
101
145
 
102
146
  def _init_params_grad_info(self, module, params_dict):
@@ -115,108 +159,116 @@ class BaseHookManager(ABC):
115
159
  # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
116
160
  self.data_collector.handle_data(grad_name, data_info,
117
161
  flush=self.data_collector.data_processor.is_terminated)
162
+ self.data_collector.params_grad_record[grad_name] = True
118
163
  # 记录当前模块的参数梯度信息已占位
119
164
  BaseHookManager.params_grad_info[grad_name] = True
120
165
 
121
- def _should_execute_hook(self, hook_type, module, is_forward, tid):
122
- is_module_hook = hook_type == Const.MODULE
123
- if is_module_hook and not Runtime.is_running:
124
- return False
125
- elif not is_module_hook and is_forward and not Runtime.is_running:
166
+ def _should_execute_hook(self, hook_type, tid, is_forward=True):
167
+ is_api_hook = hook_type == Const.API
168
+ if BaseHookManager.inner_switch[tid]:
126
169
  return False
127
- elif not is_module_hook and not is_forward and not module.forward_data_collected:
170
+ if not is_api_hook and not Runtime.is_running:
128
171
  return False
129
- if BaseHookManager.inner_switch[tid]:
172
+ elif is_api_hook and is_forward and not Runtime.is_running:
130
173
  return False
131
174
  if not self.data_collector or self.data_collector.data_processor.is_terminated:
132
175
  return False
133
176
  return True
134
177
 
135
- def _build_grad_hook(self, module, ori_name, param_name):
178
+ def _build_grad_hook(self, ori_name, param_name):
136
179
  def hook_fn(grad):
137
180
  tid = threading.get_ident()
138
- if not self._should_execute_hook(Const.MODULE, module, False, tid):
181
+ if not self._should_execute_hook(Const.MODULE, tid):
139
182
  return
140
183
  with ThreadSafe():
184
+ original_state = self.ensure_gc_enabled()
141
185
  BaseHookManager.inner_switch[tid] = True
142
186
  self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
143
187
  BaseHookManager.inner_switch[tid] = False
188
+ self.restore_gc_state(original_state)
144
189
  return
145
190
 
146
191
  return hook_fn
147
192
 
148
- def _build_forward_pre_hook(self, hook_type, full_name, api_name):
193
+ def _build_forward_pre_hook(self, hook_type, api_name):
149
194
  def forward_pre_hook(module, args, kwargs=None):
150
- """
151
- 为确保多线程场景下 L1 级别数据采集的正确性,每个封装后的 API 的 init 方法和 forward_pre_hook 需要确保在一个线程内连续完成,
152
- 因此在 API 的 init 方法执行 ThreadSafe.acquire() 加锁操作,
153
- 并且在 API 的 forward_pre_hook 方法执行 ThreadSafe.release() 释放锁操作。
154
- """
155
195
  if hook_type == Const.MODULE:
156
- return
196
+ return None
157
197
 
158
198
  tid = threading.get_ident()
159
- if not self._should_execute_hook(hook_type, module, True, tid):
160
- ThreadSafe.release()
161
- return
199
+ if not self._should_execute_hook(hook_type, tid):
200
+ return None
162
201
 
163
- module.forward_data_collected = True
164
- self._add_count(api_name)
165
- if getattr(self.config, "online_run_ut", False):
166
- ThreadSafe.release()
167
- return
202
+ with ThreadSafe():
203
+ original_state = self.ensure_gc_enabled()
204
+ self._register_forward_hook(module, api_name)
205
+ BaseHookManager.inner_api_count[tid] += 1
206
+ if BaseHookManager.inner_api_count[tid] != 1:
207
+ return None
208
+
209
+ full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD
210
+ full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD
211
+ module.full_forward_name = full_forward_name
212
+ if kwargs is None:
213
+ kwargs = module.msprobe_input_kwargs.get(tid, {}) if hasattr(module, 'msprobe_input_kwargs') else {}
214
+ BaseHookManager.inner_switch[tid] = True
215
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
168
216
 
169
- BaseHookManager.inner_switch[tid] = True
170
- if kwargs is None:
171
- kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
172
- try:
217
+ args = self._register_backward_hook(module, full_backward_name, args)
173
218
  with self._no_grad_context():
174
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
175
- self.data_collector.update_api_or_module_name(full_name)
219
+ self.data_collector.update_api_or_module_name(full_forward_name)
176
220
  self.data_collector.forward_input_data_collect(
177
- full_name,
221
+ full_forward_name,
178
222
  module,
179
223
  self._pid,
180
224
  module_input_output,
181
225
  self._is_recompute
182
226
  )
183
- except Exception as e:
184
- logger.error(f"The forward pre hook execution of the {full_name} API failed.")
185
- raise e
186
- finally:
187
227
  BaseHookManager.inner_switch[tid] = False
188
- ThreadSafe.release()
228
+ self.restore_gc_state(original_state)
229
+ return args
189
230
 
190
231
  return forward_pre_hook
191
232
 
192
- def _build_forward_hook(self, hook_type, full_name):
233
+ def _build_forward_hook(self, hook_type, api_name):
193
234
  def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
194
235
  tid = threading.get_ident()
195
- if not self._should_execute_hook(hook_type, module, True, tid):
196
- self._clear_input_kwargs(module)
236
+ if not self._should_execute_hook(hook_type, tid):
237
+ self._clear_input_kwargs(module, tid)
197
238
  return None
198
239
 
199
240
  with ThreadSafe():
200
- kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
241
+ original_state = self.ensure_gc_enabled()
242
+ if hook_type == Const.API:
243
+ if BaseHookManager.inner_api_count[tid] != 1:
244
+ if BaseHookManager.inner_api_count[tid] > 1:
245
+ BaseHookManager.inner_api_count[tid] -= 1
246
+ self._clear_input_kwargs(module, tid)
247
+ return None
248
+
249
+ kwargs, output = self._process_kwargs_and_output(
250
+ module,
251
+ tid,
252
+ hook_type,
253
+ kwargs_or_output,
254
+ output_or_kwargs
255
+ )
201
256
  BaseHookManager.inner_switch[tid] = True
202
- self.data_collector.update_api_or_module_name(full_name)
203
257
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
258
+ if hook_type == Const.API:
259
+ full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD
260
+ full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD
261
+ output = self._register_backward_pre_hook(module, full_backward_name, output)
262
+
204
263
  with self._no_grad_context():
205
- if getattr(self.config, "online_run_ut", False):
206
- if self.data_collector.scope and not self.data_collector.scope.check(full_name):
207
- return None
208
- if self.attl_manager:
209
- self.attl_manager.attl_send(full_name, args, kwargs, output)
210
- BaseHookManager.inner_switch[tid] = False
211
- return None
212
264
  if hook_type == Const.MODULE:
213
265
  params_dict = self._get_params_dict(module)
214
266
  setattr(module_input_output, Const.PARAMS, params_dict)
215
267
  if params_dict:
216
- self._register_param_hook(full_name, module, params_dict)
217
- self.data_collector.update_api_or_module_name(full_name)
268
+ self._register_param_hook(api_name, module, params_dict)
269
+ self.data_collector.update_api_or_module_name(api_name)
218
270
  self.data_collector.forward_data_collect(
219
- full_name,
271
+ api_name,
220
272
  module,
221
273
  self._pid,
222
274
  module_input_output,
@@ -224,37 +276,40 @@ class BaseHookManager(ABC):
224
276
  )
225
277
  self._init_params_grad_info(module, params_dict)
226
278
  else:
279
+ self.data_collector.update_api_or_module_name(full_forward_name)
227
280
  self.data_collector.forward_output_data_collect(
228
- full_name,
281
+ full_forward_name,
229
282
  module,
230
283
  self._pid,
231
284
  module_input_output,
232
285
  self._is_recompute
233
286
  )
234
- self._clear_input_kwargs(module)
287
+ self._add_count(api_name)
288
+ BaseHookManager.inner_api_count[tid] -= 1
289
+ self._clear_input_kwargs(module, tid)
235
290
 
236
291
  if self.data_collector.if_return_forward_new_output():
237
292
  forward_new_output = self.data_collector.get_forward_new_output()
238
293
  BaseHookManager.inner_switch[tid] = False
239
294
  return forward_new_output
240
295
 
241
- BaseHookManager.inner_switch[tid] = False
242
- return output
296
+ BaseHookManager.inner_switch[tid] = False
297
+ self.restore_gc_state(original_state)
298
+ return output
243
299
 
244
300
  return forward_hook
245
301
 
246
302
  def _build_backward_hook(self, hook_type, full_name):
247
303
  def backward_hook(module, grad_input, grad_output):
248
304
  tid = threading.get_ident()
249
- if not self._should_execute_hook(hook_type, module, False, tid):
305
+ if not self._should_execute_hook(hook_type, tid, is_forward=False):
250
306
  return
251
307
 
252
308
  with ThreadSafe():
309
+ original_state = self.ensure_gc_enabled()
253
310
  BaseHookManager.inner_switch[tid] = True
254
311
  self.data_collector.update_api_or_module_name(full_name)
255
- if getattr(self.config, "online_run_ut", False):
256
- BaseHookManager.inner_switch[tid] = False
257
- return
312
+
258
313
  need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
259
314
  if need_exchange:
260
315
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
@@ -267,6 +322,10 @@ class BaseHookManager(ABC):
267
322
  module_input_output,
268
323
  self._is_recompute
269
324
  )
325
+ if hook_type == Const.MODULE:
326
+ params_dict = self._get_params_dict(module)
327
+ self.data_collector.params_data_collect_in_bw_hook(params_dict, full_name)
270
328
  BaseHookManager.inner_switch[tid] = False
329
+ self.restore_gc_state(original_state)
271
330
 
272
331
  return backward_hook