mindstudio-probe: mindstudio_probe-8.1.2-py3-none-any.whl → mindstudio_probe-8.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181):
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -17,6 +17,7 @@ import os
17
17
  from collections import defaultdict, namedtuple
18
18
 
19
19
  import mindspore as ms
20
+ from mindspore.ops.operations import _inner_ops as inner
20
21
  from mindspore._c_expression import MSContext
21
22
 
22
23
  from msprobe.core.common.const import Const, MsgConst
@@ -28,7 +29,8 @@ from msprobe.mindspore.common.const import Const as MsConst
28
29
  from msprobe.mindspore.common.utils import (
29
30
  set_register_backward_hook_functions,
30
31
  check_save_param,
31
- is_graph_mode_cell_dump_allowed
32
+ is_graph_mode_cell_dump_allowed,
33
+ wrap_backward_hook_call_func
32
34
  )
33
35
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
34
36
  from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
@@ -41,6 +43,7 @@ from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
41
43
 
42
44
  try:
43
45
  from mindspore._c_expression import _dump_start, _dump_stop, _dump_step, _set_init_iter, _dump_set_dynamic
46
+ import mindspore as ms
44
47
  except ImportError:
45
48
  enable_dynamic_kbyk_dump = False
46
49
  else:
@@ -80,6 +83,9 @@ class PrecisionDebugger(BasePrecisionDebugger):
80
83
  if self._is_kernel_dump() and not self.task_config.is_regex_valid:
81
84
  raise ValueError('Illegal regular expressions exist in the list.')
82
85
 
86
+ setattr(inner.CellBackwardHook, '__call__',
87
+ wrap_backward_hook_call_func(getattr(inner.CellBackwardHook, '__call__')))
88
+
83
89
  if self._is_kernel_dump() and _msprobe_c:
84
90
  os.environ["MS_HOOK_ENABLE"] = "on"
85
91
  _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
@@ -90,7 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
90
96
 
91
97
  Runtime.step_count = 0
92
98
  Runtime.is_running = False
93
- if enable_dynamic_kbyk_dump:
99
+ if enable_dynamic_kbyk_dump and self.config.level_ori == Const.LEVEL_L2:
94
100
  _dump_set_dynamic()
95
101
 
96
102
  @staticmethod
@@ -160,7 +166,8 @@ class PrecisionDebugger(BasePrecisionDebugger):
160
166
  instance.service.stop()
161
167
  else:
162
168
  Runtime.is_running = False
163
- if enable_dynamic_kbyk_dump:
169
+ if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
170
+ ms.runtime.synchronize()
164
171
  _dump_stop()
165
172
  if cls._is_kernel_dump() and _msprobe_c:
166
173
  _msprobe_c._PrecisionDebugger().stop()
@@ -175,8 +182,8 @@ class PrecisionDebugger(BasePrecisionDebugger):
175
182
  with ThreadSafe():
176
183
  instance.service.step()
177
184
  if is_graph_mode_cell_dump_allowed(instance.config):
178
- GraphModeCellDump.step()
179
- if enable_dynamic_kbyk_dump:
185
+ GraphModeCellDump.step(instance.config.dump_path, instance.config.step, instance.config.task)
186
+ if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
180
187
  _dump_step(1)
181
188
  if cls._is_kernel_dump() and _msprobe_c:
182
189
  _msprobe_c._PrecisionDebugger().step()
@@ -207,12 +214,9 @@ class PrecisionDebugger(BasePrecisionDebugger):
207
214
  check_save_param(variable, name, save_backward)
208
215
  except ValueError:
209
216
  return
210
-
211
- instance.config.execution_mode = cls._get_execution_mode()
212
- if cls._need_service():
213
- if not instance.service:
214
- instance.service = MindsporeService(instance.config)
215
- instance.service.save(variable, name, save_backward)
217
+ if not instance.service:
218
+ instance.service = MindsporeService(instance.config)
219
+ instance.service.save(variable, name, save_backward)
216
220
 
217
221
  @classmethod
218
222
  def _need_service(cls):
@@ -220,7 +224,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
220
224
  if not instance:
221
225
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
222
226
  if instance.config.level_ori == Const.LEVEL_L2:
223
- return False
227
+ return not instance._is_graph_dump(instance.config)
224
228
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
225
229
  return False
226
230
  else:
@@ -38,15 +38,19 @@ DEFAULT_RANK_DIR = "rank0"
38
38
  KEY_LAYERS = "layers"
39
39
  construct = {}
40
40
  cell_list = []
41
+ free_cells = {}
42
+ parent_cell_types = {}
41
43
  KEY_SIDE_EFFECT = "side_effect_io"
42
44
  KEY_TOPLAYER = "TopLayer"
43
45
  KEY_FORWARD = CoreConst.FORWARD
44
46
  KEY_BACKWARD = CoreConst.BACKWARD
45
47
  KEY_INPUT = CoreConst.INPUT
46
48
  KEY_OUTPUT = CoreConst.OUTPUT
47
- KEY_DUMP_TENSOR_DATA = "dump_tensor_data_"
49
+ KEY_DUMP_TENSOR_DATA = "dump_tensor_data/"
48
50
  KEY_STATISTIC_CSV = "statistic.csv"
49
51
  KEY_TD_FLAG = "td_flag"
52
+ # 设置落盘文件检测超时时间
53
+ TIMEOUT = 600
50
54
  td = ops.TensorDump()
51
55
  if (ms.__version__ >= "2.5.0"):
52
56
  td_in = ops.TensorDump("in")
@@ -219,8 +223,16 @@ def cell_construct_wrapper(func, self):
219
223
  def sort_filenames(path):
220
224
  filenames = os.listdir(path)
221
225
  id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$')
222
- filenames.sort(key=lambda x: int(id_pattern.findall(x)[0]))
223
- return filenames
226
+ # 只保留能提取到数字id的文件,避免数组越界
227
+ valid_files = []
228
+ for filename in filenames:
229
+ match = id_pattern.findall(filename)
230
+ if match and match[0].isdigit():
231
+ valid_files.append(filename)
232
+ else:
233
+ logger.warning(f"File {filename} does not match the expected pattern and will be ignored.")
234
+ valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0]))
235
+ return valid_files
224
236
 
225
237
 
226
238
  def rename_filename(path="", data_df=None):
@@ -294,7 +306,24 @@ def check_relation(cell_name, parent_cell_name):
294
306
  return False
295
307
 
296
308
 
309
+ def get_parent_cell_name(child_cell_name):
310
+ parent_cell_name = ''
311
+
312
+ last_dot_index = child_cell_name.rfind(CoreConst.SEP)
313
+ if last_dot_index == -1:
314
+ return parent_cell_name
315
+
316
+ layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$"
317
+ if re.search(layers_pattern, child_cell_name):
318
+ parent_cell_name = re.sub(layers_pattern, '', child_cell_name)
319
+ else:
320
+ parent_cell_name = child_cell_name[:last_dot_index]
321
+
322
+ return parent_cell_name
323
+
324
+
297
325
  def get_construct(cell_list_input):
326
+ global free_cells, parent_cell_types
298
327
  for cell in cell_list_input:
299
328
  cell_name = get_cell_name(cell)
300
329
  cell_data_mode = get_data_mode(cell)
@@ -308,7 +337,20 @@ def get_construct(cell_list_input):
308
337
  found_flag = True
309
338
  break
310
339
  if not found_flag:
311
- construct.update({cell: None})
340
+ cell_name_with_mode = f'{cell_name}{CoreConst.SEP}{cell_data_mode}'
341
+ if cell_name_with_mode in free_cells:
342
+ construct.update({cell: free_cells.get(cell_name_with_mode)})
343
+ continue
344
+
345
+ parent_cell = None
346
+ parent_cell_name = get_parent_cell_name(cell_name)
347
+ if parent_cell_name and cell_name in parent_cell_types:
348
+ parent_cell = CoreConst.SEP.join([CoreConst.CELL, parent_cell_name, parent_cell_types.get(cell_name)])
349
+ second_last_dot_index = cell.rfind(CoreConst.SEP, 0, cell.rfind(CoreConst.SEP))
350
+ parent_cell = f'{parent_cell}{cell[second_last_dot_index:]}'
351
+ free_cells[cell_name_with_mode] = parent_cell
352
+
353
+ construct.update({cell: parent_cell})
312
354
 
313
355
 
314
356
  def generate_construct(path):
@@ -462,7 +504,7 @@ def process_csv(path):
462
504
  if col_name in columns:
463
505
  value = convert_special_values(row[col_name])
464
506
  tensor_json[json_key] = value
465
-
507
+
466
508
  if io_key == KEY_INPUT:
467
509
  data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json])
468
510
  elif io_key == KEY_OUTPUT:
@@ -534,59 +576,75 @@ def generate_stack_info(path):
534
576
  logger.info(f"Stack data saved to {json_path}")
535
577
 
536
578
 
537
- def is_download_finished(directory, interval=3):
579
+ def is_download_finished(directory, save_flag):
538
580
  """
539
581
  判断指定目录在一段时间后是否有数据被下载完成
540
582
  :param directory: 指定目录的路径
541
- :param interval: 检查的时间间隔(秒),默认为 3 秒
583
+ :param save_flag: 数据落盘完成后的标志文件
542
584
  :return: 如有数据被下载完成返回 True,否则返回 False
543
585
  """
586
+ # 设定一定的延迟间隔,避免频繁进行磁盘的io读取操作
587
+ time.sleep(0.5)
588
+ logger.info("Waiting for download...")
544
589
  # 检查目录是否存在
545
590
  if not os.path.exists(directory):
546
591
  logger.warning(f"The specified directory {directory} does not exist.")
547
592
  return False
548
- initial_modification_time = os.path.getmtime(directory)
549
- time.sleep(interval)
550
- current_modification_time = os.path.getmtime(directory)
551
- # 比较初始和当前修改时间
552
- if current_modification_time > initial_modification_time:
553
- return False
554
- else:
555
- return True
593
+
594
+ # 遍历当前目录中的所有条目
595
+ for entry_path in os.listdir(directory):
596
+ if entry_path.startswith(save_flag):
597
+ return True
556
598
 
599
+ return False
600
+
601
+
602
+ def process_step(dump_path, flag_path, step, step_list):
603
+ if step not in step_list:
604
+ return
605
+
606
+ if not os.path.exists(dump_path):
607
+ logger.warning('No grap cell data is dumped.')
608
+ create_directory(dump_path)
609
+ return
557
610
 
558
- def process(dump_path):
559
611
  rank_id = os.environ.get('RANK_ID')
560
612
  rank_dir = DEFAULT_RANK_DIR
561
613
  if rank_id is not None:
562
614
  rank_dir = CoreConst.RANK + str(rank_id)
563
615
 
564
- step_dir_list = os.listdir(dump_path)
565
- for step_dir in step_dir_list:
566
- step_path = os.path.join(dump_path, step_dir)
567
- rank_path = os.path.join(step_path, rank_dir)
568
- npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
569
- while True:
570
- is_finished = is_download_finished(npy_path)
571
- if not is_finished:
572
- logger.info("There is data being downloaded in the specified directory, continue checking...")
573
- else:
574
- logger.info("There is no data being downloaded in the specified directory, Stop checking.")
575
- break
576
- logger.info("==========Start processing data that has already been stored on the disk!==========")
577
- rename_filename(path=npy_path)
578
- generate_construct(npy_path)
579
- generate_dump_info(npy_path)
580
- generate_stack_info(npy_path)
581
- # 单卡场景,rank目录名称为rank
582
- if rank_id is None:
583
- new_rank_path = os.path.join(step_path, CoreConst.RANK)
584
- try:
585
- move_directory(rank_path, new_rank_path)
586
- logger.info(f"Directory was successfully renamed to: {new_rank_path}")
587
- except Exception as e:
588
- logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
589
- logger.info("==========JSON file generation completed!==========")
616
+ step_dir = CoreConst.STEP + str(step)
617
+
618
+ step_path = os.path.join(dump_path, step_dir)
619
+ rank_path = os.path.join(step_path, rank_dir)
620
+ npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
621
+ save_finish_flag = f"step_{step}"
622
+ start_time = time.time()
623
+ while True:
624
+ is_finished = is_download_finished(flag_path, save_finish_flag)
625
+ if not is_finished:
626
+ logger.info("There is data being downloaded in the specified directory, continue checking...")
627
+ else:
628
+ logger.info("There is no data being downloaded in the specified directory, Stop checking.")
629
+ break
630
+ elapsed_time = time.time() - start_time
631
+ if elapsed_time > TIMEOUT:
632
+ logger.error(f"Check timed out after {TIMEOUT} seconds. Exiting.")
633
+ return
634
+ logger.info(f"==========Start processing step_{step}'s data that has already been stored on the disk!==========")
635
+ rename_filename(path=npy_path)
636
+ generate_construct(npy_path)
637
+ generate_dump_info(npy_path)
638
+ generate_stack_info(npy_path)
639
+ # 单卡场景,rank目录名称为rank
640
+ if rank_id is None:
641
+ new_rank_path = os.path.join(step_path, CoreConst.RANK)
642
+ try:
643
+ move_directory(rank_path, new_rank_path)
644
+ logger.info(f"Directory was successfully renamed to: {new_rank_path}")
645
+ except Exception as e:
646
+ logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
647
+ logger.info(f"==========Step_{step}'s JSON file generation completed!==========")
590
648
 
591
649
 
592
650
  # 删除csv文件中每行数据最后面的逗号
@@ -644,7 +702,15 @@ def merge_file(dump_path, rank_dir, file_dict):
644
702
  " and the index is out of bounds.")
645
703
 
646
704
 
647
- def process_statistics(dump_path):
705
+ def process_statistics_step(dump_path, step, step_list):
706
+ if step_list and step not in step_list:
707
+ return
708
+
709
+ if not os.path.exists(dump_path):
710
+ logger.warning('No grap cell data is dumped.')
711
+ create_directory(dump_path)
712
+ return
713
+
648
714
  rank_id = os.environ.get('RANK_ID')
649
715
  rank_dir_kbk = "rank_0"
650
716
  if rank_id is not None:
@@ -673,25 +739,24 @@ def process_statistics(dump_path):
673
739
 
674
740
  rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '')
675
741
  dir_list = os.listdir(dump_path)
676
- step_dir_list = [d for d in dir_list if d.startswith(CoreConst.STEP)]
677
- for step_dir in step_dir_list:
678
- step_path = os.path.join(dump_path, step_dir)
679
- rank_path = os.path.join(step_path, rank_dir)
680
- csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
681
- logger.info("==========Start processing data csv!==========")
682
- generate_construct(csv_path)
683
- generate_dump_info(csv_path)
684
- generate_stack_info(csv_path)
685
- remove_path(rank_path_kbk)
686
- # 单卡场景,rank目录名称为rank
687
- if rank_id is None:
688
- new_rank_path = os.path.join(step_path, CoreConst.RANK)
689
- try:
690
- move_directory(rank_path, new_rank_path)
691
- logger.info(f"Directory was successfully renamed to: {new_rank_path}")
692
- except Exception as e:
693
- logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
694
- logger.info("==========JSON file generation completed!==========")
742
+ step_dir = CoreConst.STEP + str(step)
743
+ step_path = os.path.join(dump_path, step_dir)
744
+ rank_path = os.path.join(step_path, rank_dir)
745
+ csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
746
+ logger.info("==========Start processing data csv!==========")
747
+ generate_construct(csv_path)
748
+ generate_dump_info(csv_path)
749
+ generate_stack_info(csv_path)
750
+ remove_path(rank_path_kbk)
751
+ # 单卡场景,rank目录名称为rank
752
+ if rank_id is None:
753
+ new_rank_path = os.path.join(step_path, CoreConst.RANK)
754
+ try:
755
+ move_directory(rank_path, new_rank_path)
756
+ logger.info(f"Directory was successfully renamed to: {new_rank_path}")
757
+ except Exception as e:
758
+ logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
759
+ logger.info("==========JSON file generation completed!==========")
695
760
 
696
761
 
697
762
  def get_yaml_keys(yaml_data):
@@ -786,7 +851,7 @@ def create_kbyk_json(dump_path, summary_mode, step):
786
851
 
787
852
 
788
853
  def start(config: CellDumpConfig):
789
- global dump_task
854
+ global dump_task, parent_cell_types
790
855
  dump_task = config.task
791
856
  net = config.net
792
857
  dump_path = config.dump_path
@@ -814,7 +879,7 @@ def start(config: CellDumpConfig):
814
879
  return
815
880
 
816
881
  if isinstance(net, nn.Cell):
817
- net = (('', net),)
882
+ net = (('', net, None),)
818
883
 
819
884
  td_config_path = ""
820
885
  try:
@@ -837,6 +902,7 @@ def start(config: CellDumpConfig):
837
902
  black_list = ["grad_reducer", ""]
838
903
 
839
904
  for name_and_model in net:
905
+ parent_cell_types[name_and_model[0]] = name_and_model[2].__class__.__name__
840
906
  for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]):
841
907
  class_name = cell.__class__.__name__
842
908
  # 跳过黑名单cell
@@ -871,7 +937,3 @@ def start(config: CellDumpConfig):
871
937
  cell.data_mode = data_mode
872
938
 
873
939
  logger.info("==========The cell_dump_process_start phase is Finished!==========")
874
- if dump_task == CoreConst.TENSOR:
875
- atexit.register(process, dump_path=dump_path)
876
- if dump_task == CoreConst.STATISTICS:
877
- atexit.register(process_statistics, dump_path=dump_path)
@@ -197,8 +197,16 @@ def cell_construct_wrapper(func, self):
197
197
  def sort_filenames(path):
198
198
  filenames = os.listdir(path)
199
199
  id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$')
200
- filenames.sort(key=lambda x: int(id_pattern.findall(x)[0]))
201
- return filenames
200
+ # 只保留能提取到数字id的文件,避免数组越界
201
+ valid_files = []
202
+ for filename in filenames:
203
+ match = id_pattern.findall(filename)
204
+ if match and match[0].isdigit():
205
+ valid_files.append(filename)
206
+ else:
207
+ logger.warning(f"File {filename} does not match the expected pattern and will be ignored.")
208
+ valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0]))
209
+ return valid_files
202
210
 
203
211
 
204
212
  def rename_filename(path="", data_df=None):
@@ -14,7 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
-
17
+ import glob
18
+ import tempfile
18
19
  import mindspore as ms
19
20
  from mindspore import hal, ops, Tensor
20
21
  from mindspore.ops.primitive import _run_op
@@ -28,15 +29,20 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
28
29
  import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
29
30
 
30
31
  tensordump_flag = True
32
+ DEFAULT_RANK_DIR = "rank0"
31
33
  try:
32
34
  from mindspore._c_expression import _tensordump_set_step
33
35
  except ImportError:
34
36
  tensordump_flag = False
35
37
 
38
+ graph_step_flag = True
39
+ try:
40
+ from mindspore._c_expression import _dump_step
41
+ except ImportError:
42
+ graph_step_flag = False
36
43
 
37
- class GraphModeCellDump:
38
- task = CoreConst.STATISTICS
39
44
 
45
+ class GraphModeCellDump:
40
46
  def __init__(self, config: DebuggerConfig, model, strict=True):
41
47
  self.net = model
42
48
  self.white_list = []
@@ -49,20 +55,40 @@ class GraphModeCellDump:
49
55
  self.list = config.list
50
56
  self.data_mode = config.data_mode
51
57
  self.file_format = config.file_format
52
- GraphModeCellDump.task = config.task
53
58
  self.summary_mode = config.summary_mode
59
+ self.task = config.task
54
60
  self.check_config(strict)
55
61
  self.set_step()
56
62
 
57
63
  @staticmethod
58
- def step():
64
+ def step(dump_path, step_list, task):
59
65
  # 更新TensorDump Step
60
- if GraphModeCellDump.task == CoreConst.TENSOR:
66
+ if task == CoreConst.TENSOR:
61
67
  hal.synchronize()
62
68
  temp_tensor = ms.Tensor([1], dtype=ms.float32)
63
- step_flag = "<tensordump-update-step>"
64
- _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
65
- ops.tensordump(step_flag, temp_tensor)
69
+ rank_id = os.environ.get('RANK_ID')
70
+ rank_dir = DEFAULT_RANK_DIR
71
+
72
+ if rank_id is not None:
73
+ rank_dir = CoreConst.RANK + str(rank_id)
74
+
75
+ with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
76
+ save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
77
+ _run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
78
+ step_flag = "<tensordump-update-step>"
79
+ _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
80
+ ops.tensordump(step_flag, temp_tensor)
81
+ cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
82
+
83
+ # 更新静态图KBK dump的step数
84
+ if task == CoreConst.STATISTICS:
85
+ if not graph_step_flag:
86
+ raise Exception(
87
+ "Importing _dump_step failed, "
88
+ "please use the latest version package of MindSpore."
89
+ )
90
+ _dump_step(1)
91
+ cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
66
92
 
67
93
  def check_config(self, strict):
68
94
  if not self.net:
@@ -16,6 +16,8 @@
16
16
  import os
17
17
  from collections import OrderedDict
18
18
  import mindspore as ms
19
+ from mindspore import hal, ops, Tensor
20
+ from mindspore.ops.primitive import _run_op
19
21
 
20
22
 
21
23
  def _iterate_items(data):
@@ -121,3 +123,12 @@ def save_grad(save_dir, name, data):
121
123
  dump_dir = generate_dump_dir(save_dir)
122
124
  suffix_name = name + '_grad'
123
125
  return _SaveGradCell(dump_dir, suffix_name)(data)
126
+
127
+
128
+ def step():
129
+ hal.synchronize()
130
+ temp_tensor = Tensor([1], dtype=ms.float32)
131
+ step_flag = "<tensordump-update-step>"
132
+ _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
133
+ ops.tensordump(step_flag, temp_tensor)
134
+ hal.synchronize()
@@ -40,36 +40,36 @@ cur_path = os.path.dirname(os.path.realpath(__file__))
40
40
  if not is_mindtorch():
41
41
  _api_types = {
42
42
  Const.MS_FRAMEWORK: {
43
- Const.MS_API_TYPE_OPS: (ops, (ops,)),
44
- Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
45
- Const.MS_API_TYPE_MINT: (mint, (mint,)),
46
- Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
47
- Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
48
- Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
43
+ Const.MS_API_TYPE_OPS: ((ops,), (ops,)),
44
+ Const.MS_API_TYPE_TENSOR: ((Tensor,), (Tensor,)),
45
+ Const.MS_API_TYPE_MINT: ((mint,), (mint,)),
46
+ Const.MS_API_TYPE_MINT_FUNC: ((functional,), (functional,)),
47
+ Const.MS_API_TYPE_COM: ((comm_func,), (comm_func,)),
48
+ Const.MS_API_TYPE_MINT_DIST: ((distributed,), (distributed,))
49
49
  }
50
50
  }
51
51
  if stub_tensor_existed:
52
52
  _api_types.get(Const.MS_FRAMEWORK).update(
53
- {Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
53
+ {Const.MS_API_TYPE_STUB_TENSOR: ((StubTensor,), (StubTensor,))}
54
54
  )
55
55
 
56
56
  _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
57
- _backlist = []
57
+ _blacklist = []
58
58
  else:
59
59
  import torch
60
60
  import torch_npu
61
61
  _api_types = {
62
62
  Const.MT_FRAMEWORK: {
63
- Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
64
- Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
65
- Const.PT_API_TYPE_TORCH: (torch, (torch,)),
66
- Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
67
- Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
63
+ Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)),
64
+ Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)),
65
+ Const.PT_API_TYPE_TORCH: ((torch,), (torch,)),
66
+ Const.PT_API_TYPE_NPU: ((torch_npu,), (torch_npu,)),
67
+ Const.PT_API_TYPE_DIST: ((torch.distributed,), (torch.distributed, torch.distributed.distributed_c10d))
68
68
  }
69
69
  }
70
70
  _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
71
71
  MsConst.SUPPORTED_API_LIST_FILE),)
72
- _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
72
+ _blacklist = []
73
73
 
74
74
  _inner_used_api = {
75
75
  Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
@@ -87,12 +87,11 @@ _inner_used_api = {
87
87
  class ApiTemplate(HOOKCell):
88
88
  def __init__(self, api_name, api_func, prefix, hook_build_func):
89
89
  self.api_name = api_name
90
- self.api_func = api_func
91
90
  self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
92
- super().__init__(hook_build_func)
93
91
  distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
94
- if prefix == distributed_prefix:
95
- self.op_is_distributed = True
92
+ self.op_is_distributed = prefix == distributed_prefix
93
+ super().__init__(hook_build_func)
94
+ self.api_func = api_func
96
95
 
97
96
  @staticmethod
98
97
  def async_to_sync(output):
@@ -161,7 +160,7 @@ def get_api_register(return_new=False):
161
160
  _inner_used_api,
162
161
  _supported_api_list_path,
163
162
  ApiTemplate,
164
- _backlist
163
+ _blacklist
165
164
  )
166
165
 
167
166
  global api_register
@@ -171,6 +170,6 @@ def get_api_register(return_new=False):
171
170
  _inner_used_api,
172
171
  _supported_api_list_path,
173
172
  ApiTemplate,
174
- _backlist
173
+ _blacklist
175
174
  )
176
175
  return api_register
@@ -19,8 +19,6 @@ from collections import defaultdict
19
19
  import mindspore as ms
20
20
  from mindspore import nn
21
21
 
22
- from msprobe.core.common.runtime import Runtime
23
- from msprobe.core.common.utils import ThreadSafe
24
22
  from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
25
23
 
26
24
  ms_version = ms.__version__
@@ -37,48 +35,28 @@ def get_cell_count(name):
37
35
  def __init__(self, hook_build_func) -> None:
38
36
  super(HOOKCell, self).__init__()
39
37
  self.msprobe_input_kwargs = {}
40
-
41
- self.tid = threading.get_ident()
42
- self.stop_hook = HOOKCell.inner_stop_hook.get(self.tid, False)
43
- if not self.stop_hook:
44
- self.forward_data_collected = False
45
-
46
- if not Runtime.is_running:
47
- return
48
- prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
49
- ThreadSafe.acquire()
50
- if callable(hook_build_func):
51
- hook_set = hook_build_func(prefix)
52
- if ms_version < "2.6.0" and not is_mindtorch():
53
- getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
38
+ prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
39
+ if callable(hook_build_func):
40
+ hook_set = hook_build_func(prefix)
41
+ if ms_version < "2.6.0" and not is_mindtorch():
42
+ getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
43
+ if hook_set.forward_hook:
54
44
  getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
55
- else:
56
- self.register_forward_pre_hook(hook_set.forward_pre_hook)
45
+ else:
46
+ self.register_forward_pre_hook(hook_set.forward_pre_hook)
47
+ if hook_set.forward_hook:
57
48
  self.register_forward_hook(hook_set.forward_hook)
58
- register_backward_hook_functions["full"](self, hook_set.backward_hook)
59
- register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
60
49
 
61
50
 
62
- # 重载call,加全局标志。
63
51
  def __call__(self, *args, **kwargs):
64
- changed = False
65
- if not self.stop_hook:
66
- HOOKCell.inner_stop_hook[self.tid] = True
67
- changed = True
68
- try:
69
- self.msprobe_input_kwargs = kwargs
70
- out = super(HOOKCell, self).__call__(*args, **kwargs)
71
- except Exception as e:
72
- raise e
73
- finally:
74
- if changed:
75
- HOOKCell.inner_stop_hook[self.tid] = False
52
+ tid = threading.get_ident()
53
+ self.msprobe_input_kwargs[tid] = kwargs
54
+ out = super(HOOKCell, self).__call__(*args, **kwargs)
76
55
  return out
77
56
 
78
57
 
79
58
  hook_cell_dict = {
80
59
  "cell_count": defaultdict(int),
81
- "inner_stop_hook": defaultdict(bool),
82
60
  "add_cell_count": staticmethod(add_cell_count),
83
61
  "get_cell_count": staticmethod(get_cell_count),
84
62
  "__init__": __init__,