mindstudio-probe 8.2.0__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +63 -61
  3. msprobe/README.md +4 -4
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/megatron_utils.py +59 -0
  8. msprobe/core/common/utils.py +14 -3
  9. msprobe/core/compare/diff_analyze/first_diff_analyze.py +16 -4
  10. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  11. msprobe/core/compare/find_first/analyzer.py +8 -7
  12. msprobe/core/compare/find_first/graph.py +11 -3
  13. msprobe/core/compare/find_first/utils.py +3 -2
  14. msprobe/core/compare/highlight.py +13 -6
  15. msprobe/core/compare/multiprocessing_compute.py +17 -10
  16. msprobe/core/compare/utils.py +14 -5
  17. msprobe/core/data_dump/data_collector.py +18 -21
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  19. msprobe/core/data_dump/json_writer.py +18 -8
  20. msprobe/core/data_dump/scope.py +4 -6
  21. msprobe/core/hook_manager.py +21 -0
  22. msprobe/core/service.py +2 -0
  23. msprobe/core/single_save/single_comparator.py +16 -3
  24. msprobe/docs/01.installation.md +7 -5
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  28. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  29. msprobe/docs/19.monitor.md +2 -0
  30. msprobe/docs/21.visualization_PyTorch.md +15 -80
  31. msprobe/docs/22.visualization_MindSpore.md +20 -104
  32. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  33. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  34. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  35. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  36. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  37. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  38. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  40. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  41. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  42. msprobe/mindspore/cell_processor.py +33 -5
  43. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  44. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  45. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  46. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  47. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  48. msprobe/pytorch/compare/utils.py +2 -1
  49. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  50. msprobe/pytorch/dump/module_dump/module_processer.py +15 -8
  51. msprobe/pytorch/monitor/module_hook.py +28 -9
  52. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  53. msprobe/visualization/builder/graph_builder.py +169 -64
  54. msprobe/visualization/builder/graph_merger.py +0 -1
  55. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  56. msprobe/visualization/db_utils.py +25 -2
  57. msprobe/visualization/graph/base_node.py +0 -24
  58. msprobe/visualization/graph/graph.py +5 -14
  59. msprobe/visualization/graph_service.py +29 -53
  60. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  61. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  62. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  63. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 import os
-
+import glob
+import tempfile
 import mindspore as ms
 from mindspore import hal, ops, Tensor
 from mindspore.ops.primitive import _run_op
@@ -28,6 +29,7 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
 import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
 
 tensordump_flag = True
+DEFAULT_RANK_DIR = "rank0"
 try:
     from mindspore._c_expression import _tensordump_set_step
 except ImportError:
@@ -41,8 +43,6 @@ except ImportError:
 
 
 class GraphModeCellDump:
-    task = CoreConst.STATISTICS
-
     def __init__(self, config: DebuggerConfig, model, strict=True):
         self.net = model
         self.white_list = []
@@ -55,29 +55,40 @@ class GraphModeCellDump:
         self.list = config.list
         self.data_mode = config.data_mode
         self.file_format = config.file_format
-        GraphModeCellDump.task = config.task
         self.summary_mode = config.summary_mode
+        self.task = config.task
         self.check_config(strict)
         self.set_step()
 
     @staticmethod
-    def step():
+    def step(dump_path, step_list, task):
         # Update the TensorDump step
-        if GraphModeCellDump.task == CoreConst.TENSOR:
+        if task == CoreConst.TENSOR:
             hal.synchronize()
             temp_tensor = ms.Tensor([1], dtype=ms.float32)
-            step_flag = "<tensordump-update-step>"
-            _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
-            ops.tensordump(step_flag, temp_tensor)
+            rank_id = os.environ.get('RANK_ID')
+            rank_dir = DEFAULT_RANK_DIR
+
+            if rank_id is not None:
+                rank_dir = CoreConst.RANK + str(rank_id)
+
+            with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
+                save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
+                _run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
+                step_flag = "<tensordump-update-step>"
+                _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
+                ops.tensordump(step_flag, temp_tensor)
+                cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
 
         # Update the step count for static-graph KBK dump
-        if GraphModeCellDump.task == CoreConst.STATISTICS:
+        if task == CoreConst.STATISTICS:
             if not graph_step_flag:
                 raise Exception(
                     "Importing _dump_step failed, "
                     "please use the latest version package of MindSpore."
                 )
             _dump_step(1)
+            cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
 
     def check_config(self, strict):
         if not self.net:
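The reworked step() relies on the standard-library behaviour of tempfile.TemporaryDirectory: a uniquely named directory is created under dump_path with the rank name as prefix, used as a scratch location for the per-step TensorDump marker, and removed again when the with-block exits. A minimal standalone sketch of that behaviour (the path and prefix below are illustrative, not taken from msprobe):

    import os
    import tempfile

    dump_path = "./dump_output"            # illustrative path, not from the diff
    os.makedirs(dump_path, exist_ok=True)

    # TemporaryDirectory(dir=..., prefix=...) creates a uniquely named directory,
    # e.g. ./dump_output/rank0abc123, and deletes it (with its contents) on exit.
    with tempfile.TemporaryDirectory(dir=dump_path, prefix="rank0") as temp_dir:
        marker = os.path.join(temp_dir, "step_0")   # per-rank, per-step marker path
        print(marker)

    print(os.listdir(dump_path))                     # the temporary directory is gone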
@@ -203,10 +203,12 @@ class MindsporeHookManager(BaseHookManager):
                 return
 
             with ThreadSafe():
+                original_state = self.ensure_gc_enabled()
                 BaseHookManager.inner_switch[tid] = True
                 module_input = ModuleBackwardInputs(grad_input=grad_input)
                 self.data_collector.update_api_or_module_name(full_name)
                 self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input)
                 BaseHookManager.inner_switch[tid] = False
+                self.restore_gc_state(original_state)
 
         return backward_pre_hook
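ensure_gc_enabled() and restore_gc_state() are not shown in this diff; judging by their names they bracket the backward data collection with Python's cyclic garbage collector switched on and then put it back. A hypothetical sketch of such helpers (an assumption, not the packaged msprobe implementation):

    import gc

    def ensure_gc_enabled():
        # Remember whether the cyclic GC was already on, then switch it on.
        original_state = gc.isenabled()
        if not original_state:
            gc.enable()
        return original_state

    def restore_gc_state(original_state):
        # Restore the collector to the state recorded before the hook ran.
        if original_state:
            gc.enable()
        else:
            gc.disable()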
@@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name):
         data_value = load_pt(data_path, to_cpu=True).detach()
     except RuntimeError as e:
         # Catch the exception raised by load_pt here
-        logger.error(f"Failed to load the .pt file at {data_path}.")
+        data_path_file_name = os.path.basename(data_path)
+        logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
         raise CompareException(CompareException.INVALID_FILE_ERROR) from e
     except AttributeError as e:
         # Catch the exception raised by the detach method here
@@ -24,8 +24,11 @@ from msprobe.pytorch.common.log import logger
 
 
 def wrap_setup_backward_hook(func):
-    def requires_clone(tensor):
-        return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
+    def requires_clone(tensor, need_check_leaf=False):
+        need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
+        if need_check_leaf:
+            need_clone &= tensor.grad_fn is not None
+        return need_clone
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def parse_tensor(item, tensor_list):
@@ -39,20 +42,20 @@ def wrap_setup_backward_hook(func):
                 parse_tensor(value, tensor_list)
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
-    def rebuild_args(item, tensor_iter):
-        if requires_clone(item):
+    def rebuild_args(item, tensor_iter, need_check_leaf=False):
+        if requires_clone(item, need_check_leaf):
             result = next(tensor_iter)
             if hasattr(result, "_base") and result._base is not None:
                 if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
                     torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
-                return result
+            return result
         if isinstance(item, list):
             for index, value in enumerate(item):
-                item[index] = rebuild_args(value, tensor_iter)
+                item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, dict):
             for key, value in item.items():
-                item[key] = rebuild_args(value, tensor_iter)
+                item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, tuple):
             if hasattr(item, '_fields'):
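The new need_check_leaf flag skips cloning for leaf tensors found inside lists and dicts; a leaf tensor has no grad_fn even when it requires grad, which is exactly the property the added `tensor.grad_fn is not None` check relies on. A quick illustration:

    import torch

    leaf = torch.ones(3, requires_grad=True)     # user-created leaf: no grad_fn
    interm = leaf * 2                            # produced by an op: has a grad_fn

    print(leaf.requires_grad, leaf.grad_fn)      # True None
    print(interm.requires_grad, interm.grad_fn)  # True <MulBackward0 object ...>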
@@ -23,13 +23,15 @@ from torch.utils.hooks import BackwardHook, RemovableHandle
 from msprobe.core.common.const import Const
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
 from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
-if torch_version_above_or_equal_2:
+torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
+if torch_version_above_or_equal_21:
     from torch.utils.checkpoint import _StopRecomputationError
 
 
@@ -61,7 +63,7 @@ def wrap_forward_with_hook_safety(module):
                 hook_fn = list(module._forward_hooks.values())[0]
                 hook_fn(module, args, kwargs, exception_output)
             raise e
-    if torch_version_above_or_equal_2:
+    if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
 
 
@@ -82,6 +84,8 @@ class ModuleProcesser:
         from megatron.core.pipeline_parallel import schedules
         origin_func_id = id(schedules.deallocate_output_tensor)
         schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+        schedules.forward_step = wrap_megatron_step(schedules.forward_step)
+        schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
         for module in list(sys.modules.values()):
             if module.__name__ == 'schedules':
                 continue
@@ -258,14 +262,16 @@
             ModuleProcesser.module_stack[tid] = []
 
         if self.module_stack[tid]:
-            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
+            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
+                else [self.module_stack[tid][-1], get_micro_step()]
         else:
             parent_name = ModuleProcesser.module_queue.find_last(full_name)
-            ModuleProcesser.module_node[full_name] = parent_name
+            ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
+                else [parent_name, get_micro_step()]
 
         ModuleProcesser.module_queue.add_name(full_name)
         ModuleProcesser.module_stack[tid].append(full_name)
-        ModuleProcesser.api_parent_node[tid] = full_name
+        ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
         if self.scope:
             self.scope.begin_module(full_name)
 
@@ -273,14 +279,15 @@
         tid = threading.get_ident()
         if torch_version_above_or_equal_2 or is_forward:
             ModuleProcesser.module_queue.remove_name(full_name)
-            ModuleProcesser.api_parent_node[tid] = None
+            ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
             if self.module_stack.get(tid):
                 ModuleProcesser.module_stack[tid].pop()
             if self.module_stack.get(tid):
-                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
+                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
+                    else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
             if self.scope:
                 self.scope.end_module(full_name)
         else:
             if self.scope:
                 self.scope.begin_module(full_name)
-            ModuleProcesser.api_parent_node[tid] = full_name
+            ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
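megatron_utils.py (new in this release, not shown here) supplies wrap_megatron_step, get_micro_step and is_megatron; the parent-node bookkeeping above then stores [name, micro_step] pairs instead of bare names. A hypothetical sketch of what such a micro-step tracker could look like, purely to illustrate the shape of the API (an assumption, not the packaged implementation):

    # Hypothetical micro-step tracker; the names mirror the imported helpers.
    _micro_step = 0
    _patched = False

    def is_megatron():
        return _patched

    def get_micro_step():
        return _micro_step

    def wrap_megatron_step(step_func, is_forward=True):
        global _patched
        _patched = True

        def wrapper(*args, **kwargs):
            global _micro_step
            result = step_func(*args, **kwargs)
            if not is_forward:
                _micro_step += 1   # assumed: advance once per completed backward micro step
            return result

        return wrapper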
@@ -19,12 +19,14 @@ import importlib
 from collections import defaultdict
 from datetime import datetime
 from functools import partial
+from itertools import cycle
 
 import pytz
 import torch
 import torch.distributed as dist
 import pandas as pd
 from torch.utils.hooks import BackwardHook
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json, make_dir
@@ -229,6 +231,8 @@ class TrainerMon:
         self.duplicate_param = {}
         self.name2tag = {}
         self.param_name_call_id = {}
+        self.flat_prefix_names = []
+        self.flat_prefix_reverse_iter = None
         self.call_id = 0
         self.module_struct = defaultdict(dict)
         self.grad_accs = []
@@ -945,13 +949,20 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
+        if isinstance(model_chunk, FSDP):
+            if not model_chunk._use_orig_params:
+                raise ValueError("Only Support fsdp1 with use_orig_params=True")
+            self.fsdp_wrapped_module = True
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
-            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
-                self.fsdp_wrapped_module = True
             if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
                 self.fsdp2_wrapped_module = True
+            if self.fsdp_wrapped_module:  # FSDP1 must record the full flat-weight prefix names, unrestricted by targets, so the flat params can be unflattened later
+                flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
+                if flat_prefix_name not in self.flat_prefix_names:
+                    self.flat_prefix_names.append(flat_prefix_name)
+
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
@@ -975,6 +986,8 @@ class TrainerMon:
                 k: get_summary_writer_tag_name(name, k, self.rank)
                 for k in keywords
             }
+        if self.fsdp_wrapped_module:
+            self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names))  # post_backward_hook is invoked in reverse order
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -1224,17 +1237,22 @@
        In each forward stage, FSDP re-registers its hooks on AccumulateGrad, so hooks registered by the monitor tool do not take effect;
        therefore _post_backward_hook is patched to collect gradients after backward and before reduce_scatter.
        """
+
        def patch_post_backward_hook(_post_backward_hook):
            def wrapper(state, handle, *unused):
                grad_dict = {}
-                offset = 0
-                for param, name in self.param2name.items():
-                    limit = param.numel()
-                    if not limit:
+                local_names = handle.flat_param._fqns
+                offsets = handle._get_flat_param_offsets()
+                shapes = handle.flat_param._shapes
+                flat_prefix = next(self.flat_prefix_reverse_iter)
+                for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
+                    grad_clip = handle.flat_param.grad[start:end + 1]
+                    grad = grad_clip.reshape(local_shape)
+                    total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
+                    if total_name not in self.origin2squash:
+                        logger.warning(f"{total_name} not in model.named_parameters(), skip.")
                        continue
-                    grad = handle.flat_param.grad[offset:offset + limit]
-                    offset += limit
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
                    if tag is None:
                        continue
                    grad_dict[tag] = grad
@@ -1242,6 +1260,7 @@
                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                out = _post_backward_hook(state, handle, *unused)
                return out
+
            return wrapper
 
        logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
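The patched hook reads FSDP1 internals (flat_param._fqns, handle._get_flat_param_offsets(), flat_param._shapes) to cut the flattened gradient back into per-parameter views. A self-contained sketch of that slicing arithmetic with made-up offsets and shapes (the inclusive end index matches the `start:end + 1` slice above):

    import torch

    # Example metadata standing in for the FSDP handle's recorded layout.
    shapes = [(2, 3), (4,)]
    offsets = [(0, 5), (6, 9)]                  # inclusive (start, end) ranges
    flat_grad = torch.arange(10, dtype=torch.float32)

    for (start, end), shape in zip(offsets, shapes):
        grad = flat_grad[start:end + 1].reshape(shape)
        print(shape, grad)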
@@ -17,7 +17,7 @@ import json
 import os
 import time
 import multiprocessing
-from multiprocessing import Pool
+from multiprocessing import Pool, Lock
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
 from msprobe.pytorch.online_dispatch.compare import Comparator
 from msprobe.core.common.utils import check_str_param, safe_get_value
 
+child_global_lock = None
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
 DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
         self.get_ops(yaml_path)
 
-        self.lock = None
+        self.lock = Lock() if process_num > 0 else None
         max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
         if process_num > max_process_num:
             logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
             raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
                                     f'but got {process_num}!')
         if process_num > 0:
-            self.pool = Pool(process_num)
+            self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
         if debug:
             logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
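The Pool now receives the lock through an initializer, so each worker stores it in a module-level global (child_global_lock above) instead of having it pickled with every task. The general pattern, as a standalone sketch:

    from multiprocessing import Pool, Lock

    _child_lock = None   # set once per worker by the initializer

    def _init_child(lock):
        global _child_lock
        _child_lock = lock

    def work(i):
        with _child_lock:          # serialize the critical section across workers
            return i * i

    if __name__ == "__main__":
        lock = Lock()
        with Pool(2, initializer=_init_child, initargs=(lock,)) as pool:
            print(pool.map(work, range(4)))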
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error("Please check train log, An exception may have occurred!")
             return
         check_file_or_directory_path(summary_path, False)
-        fp_handle = FileOpen(summary_path, "r")
-        while True:
-            json_line_data = fp_handle.readline()
-            if json_line_data == '\n':
-                continue
-            if len(json_line_data) == 0:
-                break
-            msg = json.loads(json_line_data)
-            if len(msg) < 2:
-                raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
-            self.all_summary[msg[0]] = msg[1]
-        fp_handle.close()
+        with FileOpen(summary_path, "r") as fp_handle:
+            while True:
+                json_line_data = fp_handle.readline()
+                if json_line_data == '\n':
+                    continue
+                if len(json_line_data) == 0:
+                    break
+                msg = json.loads(json_line_data)
+                if len(msg) < 2:
+                    raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
+                self.all_summary[msg[0]] = msg[1]
 
         if self.debug_flag:
             input_num = 0
@@ -163,11 +163,16 @@
 
         call_stack = get_callstack()
         self.call_stack_list.append(call_stack)
-        self.api_index += 1
-        if aten_api not in self.single_api_index_dict:
-            self.single_api_index_dict[aten_api] = 1
-        else:
-            self.single_api_index_dict[aten_api] += 1
+
+        self.lock.acquire() if self.process_num > 0 else None
+        try:
+            self.api_index += 1
+            if aten_api not in self.single_api_index_dict:
+                self.single_api_index_dict[aten_api] = 1
+            else:
+                self.single_api_index_dict[aten_api] += 1
+        finally:
+            self.lock.release() if self.process_num > 0 else None
 
         run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
 
@@ -180,7 +185,7 @@
         cpu_kwargs = []
         data_to_cpu(args, 0, cpu_args)
         data_to_cpu(kwargs, 0, cpu_kwargs)
-
+ 
         cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
         cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
 
@@ -194,7 +199,12 @@
         try:
             cpu_out = func(*cpu_args, **cpu_kwargs)
         except RuntimeError as e:
-            self.api_index -= 1
+            self.lock.acquire() if self.process_num > 0 else None
+            try:
+                self.api_index -= 1
+                self.single_api_index_dict[aten_api] -= 1
+            finally:
+                self.lock.release() if self.process_num > 0 else None
             logger.warning(f"RuntimeError: {e}")
             logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
             return npu_out
@@ -215,7 +225,7 @@
         run_param.process_flag = True
         if self.check_fun(func, run_param):
             data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
-                                         self.lock)
+                                         child_global_lock)
             self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                   error_callback=error_call)
         else:
@@ -233,12 +243,20 @@
             return True
         return False
 
+    @staticmethod
+    def _init_child_process(lock):
+        global child_global_lock
+        child_global_lock = lock
+
     def get_dir_name(self, tag):
         # guarantee file uniqueness
         time.sleep(1)
-        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        # Time format: year-month-day-hour-minute-second-millisecond (to a thousandth of a second)
+        time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3]  # keep the first 3 digits as milliseconds
+
         if tag is None or not isinstance(tag, str):
             logger.warning('There is not tag or the type of tag is not string.')
             # Directory name format: msprobe_rank{device ID}_{millisecond timestamp}
             dir_name = f'msprobe_rank{self.device_id}_{time_now}'
         else:
             dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
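For reference, the %f (microseconds) directive is documented for datetime.strftime rather than for time.strftime; if the goal is the millisecond-precision stamp the new comment describes, a datetime-based sketch would be:

    from datetime import datetime

    # %f yields 6-digit microseconds; keep the first 3 digits as milliseconds.
    time_now = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
    print(time_now)   # e.g. 20240101123045123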