mindstudio-probe 8.2.0__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +63 -61
  3. msprobe/README.md +4 -4
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/megatron_utils.py +59 -0
  8. msprobe/core/common/utils.py +14 -3
  9. msprobe/core/compare/diff_analyze/first_diff_analyze.py +16 -4
  10. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  11. msprobe/core/compare/find_first/analyzer.py +8 -7
  12. msprobe/core/compare/find_first/graph.py +11 -3
  13. msprobe/core/compare/find_first/utils.py +3 -2
  14. msprobe/core/compare/highlight.py +13 -6
  15. msprobe/core/compare/multiprocessing_compute.py +17 -10
  16. msprobe/core/compare/utils.py +14 -5
  17. msprobe/core/data_dump/data_collector.py +18 -21
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  19. msprobe/core/data_dump/json_writer.py +18 -8
  20. msprobe/core/data_dump/scope.py +4 -6
  21. msprobe/core/hook_manager.py +21 -0
  22. msprobe/core/service.py +2 -0
  23. msprobe/core/single_save/single_comparator.py +16 -3
  24. msprobe/docs/01.installation.md +7 -5
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  28. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  29. msprobe/docs/19.monitor.md +2 -0
  30. msprobe/docs/21.visualization_PyTorch.md +15 -80
  31. msprobe/docs/22.visualization_MindSpore.md +20 -104
  32. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  33. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  34. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  35. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  36. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  37. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  38. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  40. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  41. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  42. msprobe/mindspore/cell_processor.py +33 -5
  43. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  44. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  45. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  46. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  47. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  48. msprobe/pytorch/compare/utils.py +2 -1
  49. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  50. msprobe/pytorch/dump/module_dump/module_processer.py +15 -8
  51. msprobe/pytorch/monitor/module_hook.py +28 -9
  52. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  53. msprobe/visualization/builder/graph_builder.py +169 -64
  54. msprobe/visualization/builder/graph_merger.py +0 -1
  55. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  56. msprobe/visualization/db_utils.py +25 -2
  57. msprobe/visualization/graph/base_node.py +0 -24
  58. msprobe/visualization/graph/graph.py +5 -14
  59. msprobe/visualization/graph_service.py +29 -53
  60. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  61. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  62. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  63. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,8 @@ from msprobe.core.compare.utils import gen_api_batches
26
26
 
27
27
  cur_dir = os.path.dirname(os.path.realpath(__file__))
28
28
  diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
29
+ ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')
30
+ ignore_list = load_yaml(ignore_op_list_yaml_path)
29
31
  thresholds = load_yaml(diff_threshold_yaml_path)
30
32
  cmp_metrics = thresholds.get('compare_metrics')
31
33
 
@@ -51,7 +53,7 @@ class FirstDiffAnalyze:
51
53
  return True
52
54
  return False
53
55
 
54
- def single_api_check(self, result_slice, header):
56
+ def single_api_check(self, result_slice, header, api_name=None):
55
57
  """
56
58
  单个api差异检查
57
59
 
@@ -65,14 +67,18 @@ class FirstDiffAnalyze:
65
67
  }
66
68
 
67
69
  column_indices = {name: idx for idx, name in enumerate(header)}
68
-
70
+ output_idx = -1
69
71
  for line in result_slice:
70
72
  op_item = {
71
73
  column_name: line[column_indices[column_name]]
72
74
  for column_name in header
73
75
  }
74
76
  single_check_result['op_items'].append(op_item)
75
-
77
+ if op_item['state'] != 'output':
78
+ continue
79
+ output_idx += 1
80
+ if output_idx in ignore_list.get(api_name, []):
81
+ continue
76
82
  # set is_same
77
83
  if self.mode_config.dump_mode == Const.MD5:
78
84
  if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF:
@@ -117,7 +123,13 @@ class FirstDiffAnalyze:
117
123
  with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
118
124
  for api_batch in api_batches:
119
125
  result_slice = result[api_batch.start: api_batch.params_grad_end_index]
120
- check_result[api_batch.api_name] = self.single_api_check(result_slice, header)
126
+ api_compo = api_batch.api_name.split('.')
127
+ # suppose name is Tensor.MatMul.0.forward
128
+ if len(api_compo) < 4:
129
+ continue
130
+ # get MatMul as api_name
131
+ api_name = api_compo[-3]
132
+ check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
121
133
  progress_bar.update(1)
122
134
 
123
135
  return check_result
@@ -0,0 +1,3 @@
1
+ npu_fusion_attention:
2
+ - 4
3
+ - 5
@@ -47,7 +47,6 @@ class DiffAnalyzer:
47
47
  analyze_func()
48
48
  if self._diff_nodes:
49
49
  self._gen_analyze_info()
50
- self._post_process()
51
50
  return
52
51
  logger.info('Cannot find any diff node, no need to generate analyze file.')
53
52
 
@@ -56,12 +55,6 @@ class DiffAnalyzer:
56
55
  self._resolve_input_path(self._output_path)
57
56
  logger.info("Pre Process completed.")
58
57
 
59
- def _post_process(self):
60
- for rank_path in self._paths.values():
61
- dump_path = rank_path.dump_path
62
- logger.debug(f"Remove {dump_path} success")
63
- logger.info("Post Process completed.")
64
-
65
58
  """
66
59
  这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中
67
60
  """
@@ -105,6 +98,8 @@ class DiffAnalyzer:
105
98
  logger.warning(f'Rank {path.rank} has no dump data!')
106
99
  continue
107
100
  for op_name, op_data in dump_data.items():
101
+ if is_ignore_op(op_name):
102
+ continue
108
103
  if is_communication_op(op_name):
109
104
  self._first_comm_nodes[path.rank] = op_name
110
105
  break
@@ -131,10 +126,16 @@ class DiffAnalyzer:
131
126
  for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
132
127
  searched_ranks.add(rank)
133
128
  seen_nodes = set()
129
+ last_node = None
134
130
  for cur_node in nodes.values():
131
+ is_overflow = last_node and hasattr(last_node, 'layer') and hasattr(cur_node, 'layer') and \
132
+ last_node.layer >= cur_node.layer
133
+ if is_overflow:
134
+ cur_node.layer = last_node.layer + 1
135
135
  conn_info = cur_node.find_connected_nodes()
136
136
  if not conn_info.get('ranks'):
137
137
  conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
138
+ last_node = cur_node
138
139
  if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
139
140
  logger.debug(f'Cannot find connected communication node for "{cur_node.node_id}".')
140
141
 
@@ -52,19 +52,25 @@ class DataNode:
52
52
  metrics = {}
53
53
  for cmp_data in self.op_data:
54
54
  name = cmp_data.get(CompareConst.NPU_NAME)
55
+ # 构建度量指标字典
56
+ metrics = {}
57
+
55
58
  if CompareConst.NPU_MAX in cmp_data:
56
59
  metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX),
57
60
  CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN),
58
61
  CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN),
59
62
  CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)}
60
63
  elif CompareConst.NPU_MD5 in cmp_data:
61
- metrics = {CompareConst.NPU_MD5: cmp_data.get(CompareConst.NPU_MD5)}
64
+ metrics[CompareConst.NPU_MD5] = cmp_data.get(CompareConst.NPU_MD5)
65
+
66
+ if CompareConst.NPU_P2POP_PEER in cmp_data:
67
+ metrics[CompareConst.NPU_P2POP_PEER] = cmp_data.get(CompareConst.NPU_P2POP_PEER)
62
68
 
63
69
  if cmp_data.get(CompareConst.STACK) != CompareConst.N_A and not self.stack:
64
70
  self.stack = cmp_data.get(CompareConst.STACK)
65
- if Const.INPUT in name:
71
+ if cmp_data.get('state') == "input":
66
72
  self.inputs[name] = metrics
67
- elif Const.OUTPUT in name:
73
+ elif cmp_data.get('state') == "output":
68
74
  self.outputs[name] = metrics
69
75
 
70
76
  def gen_node_info(self, path: RankPath):
@@ -161,6 +167,8 @@ class CommunicationNode:
161
167
  if val and val.startswith('[') and val.endswith(']'):
162
168
  val = [int(part) for part in val.strip('[]').split(',')]
163
169
  ranks.update(val)
170
+ elif v.get(CompareConst.NPU_P2POP_PEER) != "None":
171
+ ranks.add(v.get(CompareConst.NPU_P2POP_PEER))
164
172
 
165
173
  return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
166
174
  'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)}
@@ -120,7 +120,8 @@ def is_communication_op(op_name):
120
120
  def is_ignore_op(op_name):
121
121
  ignore_keywords = [
122
122
  'Torch.empty',
123
- 'Torch.fill'
123
+ 'Torch.fill',
124
+ 'Tensor.__setitem__'
124
125
  ]
125
126
  return any(keyword in op_name for keyword in ignore_keywords)
126
127
 
@@ -181,7 +182,7 @@ def analyze_diff_in_group(nodes_group):
181
182
  input_diff_nodes = list(filter(lambda node: node.is_diff, src_list))
182
183
  # 如果有异常回溯计算节点找到异常来源
183
184
  # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。
184
- get_compute_ops_from_comm_nodes(input_diff_nodes)
185
+ get_compute_ops_from_comm_nodes(nodes_group)
185
186
  # 筛选入参没问题但出参有问题的通信节点
186
187
  output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group))
187
188
  get_comm_ops(output_diff_nodes)
@@ -26,7 +26,7 @@ from tqdm import tqdm
26
26
  from msprobe.core.common.const import CompareConst, Const
27
27
  from msprobe.core.common.file_utils import save_workbook
28
28
  from msprobe.core.common.log import logger
29
- from msprobe.core.common.utils import get_header_index
29
+ from msprobe.core.common.utils import get_header_index, CompareException
30
30
  from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches
31
31
  from msprobe.core.compare.config import ModeConfig
32
32
 
@@ -359,18 +359,25 @@ class HighLight:
359
359
 
360
360
  def err_call(args):
361
361
  logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
362
- try:
363
- pool.close()
364
- except OSError:
365
- logger.error("Pool terminate failed")
366
362
 
367
363
  result_df_columns = result_df.columns.tolist()
368
364
  for column in result_df_columns:
369
365
  self.value_check(column)
366
+ async_results = []
370
367
  for df_chunk in chunks:
371
- pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
368
+ result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
369
+ async_results.append(result)
372
370
 
373
371
  pool.close()
372
+
373
+ for ar in async_results:
374
+ try:
375
+ ar.get(timeout=3600)
376
+ except Exception as e:
377
+ logger.error(f"Task failed with exception: {e}")
378
+ pool.terminate()
379
+ raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
380
+
374
381
  pool.join()
375
382
 
376
383
  def df_malicious_value_check(self, result_df):
@@ -52,16 +52,20 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
52
52
 
53
53
  def err_call(args):
54
54
  logger.error('multiprocess compare failed! Reason: {}'.format(args))
55
- try:
56
- pool.close()
57
- except OSError as e:
58
- logger.error(f'pool terminate failed: {str(e)}')
59
55
 
60
56
  for df_chunk in df_chunks:
61
57
  result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
62
58
  results.append(result)
63
- final_results = [r.get() for r in results]
59
+
64
60
  pool.close()
61
+
62
+ try:
63
+ final_results = [r.get(timeout=3600) for r in results]
64
+ except Exception as e:
65
+ logger.error(f"Task failed with exception: {e}")
66
+ pool.terminate()
67
+ raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
68
+
65
69
  pool.join()
66
70
  return pd.concat(final_results, ignore_index=True)
67
71
 
@@ -277,10 +281,6 @@ class CompareRealData:
277
281
 
278
282
  def err_call(args):
279
283
  logger.error('multiprocess compare failed! Reason: {}'.format(args))
280
- try:
281
- pool.close()
282
- except OSError:
283
- logger.error("pool terminate failed")
284
284
 
285
285
  progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
286
286
 
@@ -298,7 +298,14 @@ class CompareRealData:
298
298
  )
299
299
  results.append(result)
300
300
 
301
- final_results = [r.get() for r in results]
302
301
  pool.close()
302
+
303
+ try:
304
+ final_results = [r.get(timeout=3600) for r in results]
305
+ except Exception as e:
306
+ logger.error(f"Task failed with exception: {e}")
307
+ pool.terminate()
308
+ raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
309
+
303
310
  pool.join()
304
311
  return pd.concat(final_results, ignore_index=True)
@@ -695,10 +695,6 @@ def get_sorted_ranks(npu_dump_dir, bench_dump_dir):
695
695
  def multi_statistics_compare(func, func_args):
696
696
  def err_call(args):
697
697
  logger.error(f'Multiprocess statistics compare failed! Reason: {args}')
698
- try:
699
- pool.close()
700
- except OSError:
701
- logger.error("Pool terminate failed")
702
698
 
703
699
  compare_func, input_param_nr_list, output_path, kwargs = func_args
704
700
 
@@ -715,9 +711,22 @@ def multi_statistics_compare(func, func_args):
715
711
  chunks[i].append(input_param_nr_list[param_num - remainder + i])
716
712
 
717
713
  pool = multiprocessing.Pool(process_num)
714
+
715
+ async_results = []
718
716
  for chunk in chunks:
719
- pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call)
717
+ result = pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call)
718
+ async_results.append(result)
719
+
720
720
  pool.close()
721
+
722
+ for ar in async_results:
723
+ try:
724
+ ar.get(timeout=3600)
725
+ except Exception as e:
726
+ logger.error(f"Task failed with exception: {e}")
727
+ pool.terminate()
728
+ raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
729
+
721
730
  pool.join()
722
731
 
723
732
 
@@ -23,6 +23,7 @@ from msprobe.core.data_dump.json_writer import DataWriter
23
23
  from msprobe.core.common.log import logger
24
24
  from msprobe.core.common.const import Const
25
25
  from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
26
+ from msprobe.core.common.megatron_utils import MegatronStepInfo, get_micro_step, is_megatron
26
27
 
27
28
 
28
29
  def build_data_collector(config):
@@ -270,15 +271,20 @@ class DataCollector:
270
271
  if self.config.level not in DataCollector.level_without_construct:
271
272
  if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
272
273
  if self.optimizer_status_first_start[self.optimizer_status]:
273
- self.data_writer.update_construct({self.optimizer_status: None})
274
+ self.data_writer.update_construct(
275
+ {self.optimizer_status: None if not is_megatron() else [None, get_micro_step()]})
274
276
  self.optimizer_status_first_start[self.optimizer_status] = False
275
- self.data_writer.update_construct({name: self.optimizer_status})
277
+ self.data_writer.update_construct(
278
+ {name: self.optimizer_status if not is_megatron() else [self.optimizer_status, get_micro_step()]})
276
279
  else:
277
280
  if self.config.level == Const.LEVEL_MIX and \
278
281
  not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
279
282
  self.data_writer.update_construct(
280
283
  {name: self.module_processor.api_parent_node.get(threading.get_ident())}
281
284
  )
285
+ if MegatronStepInfo.is_megatron:
286
+ micro_step_number = max(MegatronStepInfo.forward_micro_step, MegatronStepInfo.backward_micro_step)
287
+ self.data_writer.update_construct({Const.MEGATRON_MICRO_STEP_NUMBER: micro_step_number})
282
288
 
283
289
  self.data_writer.update_construct(self.module_processor.module_node)
284
290
 
@@ -302,25 +308,16 @@ class DataCollector:
302
308
  self.data_processor.update_iter(current_iter)
303
309
 
304
310
  def params_data_collect(self, name, param_name, pid, data):
305
- try:
306
- grad_name = name + Const.SEP + Const.PARAMS_GRAD
307
- self.update_api_or_module_name(grad_name)
308
- if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
309
- if self.data_writer.cache_data.get("data"):
310
- self.data_writer.cache_data.get("data").pop(grad_name, None)
311
- self.params_grad_record[grad_name] = False
312
- return
313
- data_info = self.data_processor.analyze_params(grad_name, param_name, data)
314
- self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
315
- self.params_grad_record[grad_name] = False
316
- except Exception as e:
317
- error_type = type(e).__name__
318
- tb = traceback.format_exc()
319
- self.data_writer.write_error_log(
320
- f"[ERROR] params_data_collect failed: "
321
- f"name={name}, param_name={param_name}, pid={pid}\n{tb}",
322
- error_type=error_type
323
- )
311
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
312
+ self.update_api_or_module_name(grad_name)
313
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
314
+ if self.data_writer.cache_data.get("data"):
315
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
316
+ self.params_grad_record[grad_name] = False
317
+ return
318
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
319
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
320
+ self.params_grad_record[grad_name] = False
324
321
 
325
322
  def params_data_collect_in_bw_hook(self, params_dict, name):
326
323
  try:
@@ -13,13 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import ctypes
16
17
  import os
17
18
  import zlib
18
- import ctypes
19
19
  from collections.abc import Iterable
20
+ from concurrent.futures import ThreadPoolExecutor
20
21
  from dataclasses import asdict
21
22
  from typing import List
22
- from concurrent.futures import ThreadPoolExecutor
23
23
 
24
24
  import numpy as np
25
25
  import torch
@@ -29,7 +29,6 @@ from torch.distributed.distributed_c10d import _get_default_group
29
29
  from msprobe.core.common.const import Const
30
30
  from msprobe.core.common.decorator import recursion_depth_decorator
31
31
  from msprobe.core.common.exceptions import MsprobeException
32
- from msprobe.core.common.file_utils import path_len_exceeds_limit
33
32
  from msprobe.core.common.log import logger
34
33
  from msprobe.core.common.utils import convert_tuple, is_int
35
34
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
@@ -48,15 +47,28 @@ class TensorHandler:
48
47
  def __init__(self):
49
48
  self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor")
50
49
  self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor")
50
+ self.has_async_collective_tensor = hasattr(dist, "_functional_collectives") and \
51
+ hasattr(dist._functional_collectives, "AsyncCollectiveTensor")
52
+
53
+ @staticmethod
54
+ def free_tensor(tensor, tensor_name):
55
+ try:
56
+ tensor.untyped_storage().resize_(0)
57
+ except Exception as e:
58
+ logger.warning(f"Failed to free tensor: {tensor_name}, the detail info: {e}.")
51
59
 
52
60
  def is_dtensor(self, tensor):
53
- return self.has_dtensor and isinstance(tensor, torch.distributed.tensor.DTensor)
61
+ return self.has_dtensor and isinstance(tensor, dist.tensor.DTensor)
54
62
 
55
63
  def is_fake_tensor(self, tensor):
56
64
  return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor)
57
65
 
66
+ def is_async_collective_tensor(self, tensor):
67
+ return self.has_async_collective_tensor and \
68
+ isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor)
69
+
58
70
  def is_empty_data(self, tensor):
59
- return tensor.is_meta or self.is_fake_tensor(tensor)
71
+ return tensor.is_meta or self.is_fake_tensor(tensor) or self.is_async_collective_tensor(tensor)
60
72
 
61
73
  def convert_common_tensor(self, tensor):
62
74
  if self.is_dtensor(tensor):
@@ -71,6 +83,8 @@ class TensorHandler:
71
83
  return Const.DTENSOR_TYPE
72
84
  if self.is_fake_tensor(tensor):
73
85
  return Const.FAKE_TENSOR_TYPE
86
+ if self.is_async_collective_tensor(tensor):
87
+ return Const.AC_TENSOR_TYPE
74
88
  return Const.TENSOR_TYPE
75
89
 
76
90
  def get_dtensor_info(self, tensor):
@@ -94,6 +108,18 @@ class TensorHandler:
94
108
  dtensor_info.update({"placements": placements})
95
109
  return dtensor_info
96
110
 
111
+ def save_tensor(self, tensor, file_path):
112
+ common_tensor = self.convert_common_tensor(tensor)
113
+ if self.is_empty_data(common_tensor):
114
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
115
+ return
116
+ if common_tensor.untyped_storage().data_ptr() == 0:
117
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
118
+ return
119
+ saved_tensor = common_tensor.clone().contiguous().detach()
120
+ save_pt(saved_tensor, file_path)
121
+ self.free_tensor(saved_tensor, file_path)
122
+
97
123
 
98
124
  class PytorchDataProcessor(BaseDataProcessor):
99
125
  pytorch_special_type = (
@@ -288,7 +314,7 @@ class PytorchDataProcessor(BaseDataProcessor):
288
314
 
289
315
  def dump_async_data(self):
290
316
  for file_path, tensor in self._async_dump_cache.items():
291
- save_pt(tensor.contiguous(), file_path)
317
+ self.tensor_handler.save_tensor(tensor, file_path)
292
318
  self._async_dump_cache.clear()
293
319
 
294
320
  def analyze_single_element(self, element, suffix_stack):
@@ -385,24 +411,24 @@ class PytorchDataProcessor(BaseDataProcessor):
385
411
  def _analyze_and_save_tensor(self, tensor, suffix):
386
412
  dump_data_name, file_path = self.get_save_file_path(suffix)
387
413
  single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
388
- if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0:
389
- logger.debug(
390
- "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, "
391
- f"the current api/module name is {self.current_api_or_module_name}."
392
- )
414
+ common_tensor = self.tensor_handler.convert_common_tensor(tensor)
415
+ if self.tensor_handler.is_empty_data(common_tensor):
416
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
417
+ return single_arg
418
+ if common_tensor.untyped_storage().data_ptr() == 0:
419
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
393
420
  return single_arg
394
421
 
395
422
  single_arg.update({"data_name": dump_data_name})
396
423
  if self.config.async_dump:
397
- self._async_dump_cache[file_path] = tensor.clone().detach()
424
+ self._async_dump_cache[file_path] = common_tensor.clone().detach()
398
425
  else:
399
- saved_tensor = tensor.clone().contiguous().detach()
400
- save_pt(saved_tensor, file_path)
426
+ self.tensor_handler.save_tensor(common_tensor, file_path)
401
427
  return single_arg
402
428
 
403
429
  def _analyze_and_save_ndarray(self, ndarray, suffix):
404
430
  dump_data_name, file_path = self.get_save_file_path(suffix)
405
- save_pt(torch.tensor(ndarray), file_path)
431
+ self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path)
406
432
  ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
407
433
  ndarray_json.update({"data_name": dump_data_name})
408
434
  return ndarray_json
@@ -493,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
493
519
  self._analyze_maybe_overflow_flag()
494
520
  if self.has_overflow:
495
521
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
496
- save_pt(tensor.clone().contiguous().detach(), file_path)
522
+ self.tensor_handler.save_tensor(tensor, file_path)
497
523
  self.real_overflow_nums += 1
498
524
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
499
525
  logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -538,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
538
564
 
539
565
  def _analyze_tensor(self, tensor, suffix):
540
566
  dump_data_name, file_path = self.get_save_file_path(suffix)
541
- if not path_len_exceeds_limit(file_path):
542
- self.cached_tensors_and_file_paths.update({file_path: tensor})
543
- else:
544
- logger.warning(f'The file path {file_path} length exceeds limit.')
567
+ self.cached_tensors_and_file_paths.update({file_path: tensor})
545
568
  single_arg = super()._analyze_tensor(tensor, suffix)
546
569
  single_arg.update({"data_name": dump_data_name})
547
570
  if not self.has_overflow and self.support_inf_nan:
@@ -13,18 +13,18 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import concurrent
17
+ import copy
16
18
  import csv
17
19
  import os
18
- import copy
19
20
  import threading
20
21
  import traceback
21
22
  from datetime import datetime, timezone, timedelta
22
23
 
23
- import concurrent
24
24
  from msprobe.core.common.const import Const, FileCheckConst
25
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
26
- from msprobe.core.common.log import logger
27
25
  from msprobe.core.common.decorator import recursion_depth_decorator
26
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, check_path_before_create
27
+ from msprobe.core.common.log import logger
28
28
 
29
29
  lock = threading.Lock()
30
30
 
@@ -40,6 +40,7 @@ class DataWriter:
40
40
  self.debug_file_path = None
41
41
  self.dump_error_info_path = None
42
42
  self.flush_size = 1000
43
+ self.md5_flush_size = 5000
43
44
  self.larger_flush_size = 20000
44
45
  self.cache_data = {}
45
46
  self.cache_stack = {}
@@ -49,6 +50,7 @@ class DataWriter:
49
50
  self._error_log_initialized = False
50
51
  self._cache_logged_error_types = set()
51
52
  self.crc32_stack_list = []
53
+ self.data_updated = False
52
54
 
53
55
  @staticmethod
54
56
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -60,7 +62,7 @@ class DataWriter:
60
62
  spawn_writer = csv.writer(csv_file)
61
63
  if not is_exists:
62
64
  spawn_writer.writerow(result_header)
63
- spawn_writer.writerows([result,])
65
+ spawn_writer.writerows([result, ])
64
66
  is_new_file = not is_exists
65
67
  if is_new_file:
66
68
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -190,7 +192,7 @@ class DataWriter:
190
192
  summary_mode = getattr(cfg, "summary_mode", None)
191
193
 
192
194
  if summary_mode == Const.MD5:
193
- threshold = self.flush_size
195
+ threshold = self.md5_flush_size
194
196
  else:
195
197
  threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
196
198
 
@@ -238,6 +240,7 @@ class DataWriter:
238
240
  logger.warning(f"The dump data({dump_data}) should be a dict.")
239
241
  return
240
242
 
243
+ self.data_updated = True
241
244
  key = next(iter(new_data.keys()))
242
245
  if key in dump_data:
243
246
  dump_data.get(key).update(new_data.get(key))
@@ -246,6 +249,7 @@ class DataWriter:
246
249
 
247
250
  def update_stack(self, name, stack_data):
248
251
  with lock:
252
+ self.data_updated = True
249
253
  api_list = self.cache_stack.get(stack_data)
250
254
  if api_list is None:
251
255
  self.cache_stack.update({stack_data: [name]})
@@ -254,10 +258,12 @@ class DataWriter:
254
258
 
255
259
  def update_construct(self, new_data):
256
260
  with lock:
261
+ self.data_updated = True
257
262
  self.cache_construct.update(new_data)
258
263
 
259
264
  def update_debug(self, new_data):
260
265
  with lock:
266
+ self.data_updated = True
261
267
  self.cache_debug['data'].update(new_data)
262
268
 
263
269
  def write_data_json(self, file_path):
@@ -324,17 +330,21 @@ class DataWriter:
324
330
  stat_result = self.flush_stat_stack()
325
331
  # 遍历 cache_data,将占位符替换为最终统计值
326
332
  if stat_result:
333
+ self.data_updated = True
327
334
  self._replace_stat_placeholders(self.cache_data, stat_result)
328
335
  if self.cache_debug:
329
336
  self._replace_stat_placeholders(self.cache_debug, stat_result)
330
337
 
331
- # 2) 再 flush CRC32
332
338
  crc32_result = self.flush_crc32_stack()
333
339
  if crc32_result:
340
+ self.data_updated = True
334
341
  self._replace_crc32_placeholders(self.cache_data, crc32_result)
335
342
  if self.cache_debug:
336
343
  self._replace_crc32_placeholders(self.cache_debug, crc32_result)
337
344
 
345
+ if not self.data_updated:
346
+ return
347
+
338
348
  if self.cache_data:
339
349
  self.write_data_json(self.dump_file_path)
340
350
  if self.cache_stack:
@@ -343,4 +353,4 @@ class DataWriter:
343
353
  self.write_construct_info_json(self.construct_file_path)
344
354
  if self.cache_debug:
345
355
  self.write_debug_info_json(self.debug_file_path)
346
-
356
+ self.data_updated = False
@@ -69,8 +69,7 @@ class BaseScope(ABC):
69
69
  self.scope = scope
70
70
  self.api_list = api_list
71
71
 
72
- @staticmethod
73
- def rectify_args(scope, api_list):
72
+ def rectify_args(self, scope, api_list):
74
73
  if not isinstance(api_list, list):
75
74
  raise ScopeException(ScopeException.InvalidApiStr,
76
75
  f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
@@ -104,12 +103,11 @@ class BaseScope(ABC):
104
103
 
105
104
 
106
105
  class ListScope(BaseScope):
107
- @staticmethod
108
- def rectify_args(scope, api_list):
106
+ def rectify_args(self, scope, api_list):
109
107
  if scope and api_list:
110
108
  raise ScopeException(ScopeException.ArgConflict,
111
109
  f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
112
- return super(ListScope, ListScope).rectify_args(scope, api_list)
110
+ return super().rectify_args(scope, api_list)
113
111
 
114
112
  def check(self, name):
115
113
  if not self.scope or name in self.scope:
@@ -147,7 +145,7 @@ class RangeScope(BaseScope, ABC):
147
145
  f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
148
146
 
149
147
  def rectify_args(self, scope, api_list):
150
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
148
+ scope, api_list = super().rectify_args(scope, api_list)
151
149
  if scope and len(scope) != 2:
152
150
  raise ScopeException(ScopeException.InvalidScope,
153
151
  f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")