mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
  3. msprobe/README.md +7 -5
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +105 -27
  7. msprobe/core/common/framework_adapter.py +7 -6
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/utils.py +14 -3
  10. msprobe/core/compare/find_first/analyzer.py +8 -7
  11. msprobe/core/compare/find_first/graph.py +11 -3
  12. msprobe/core/compare/find_first/utils.py +2 -1
  13. msprobe/core/compare/highlight.py +13 -6
  14. msprobe/core/compare/multiprocessing_compute.py +17 -10
  15. msprobe/core/compare/utils.py +14 -5
  16. msprobe/core/data_dump/data_collector.py +18 -21
  17. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  18. msprobe/core/data_dump/json_writer.py +18 -8
  19. msprobe/core/data_dump/scope.py +4 -6
  20. msprobe/core/hook_manager.py +37 -3
  21. msprobe/core/service.py +18 -5
  22. msprobe/core/single_save/single_comparator.py +16 -3
  23. msprobe/docs/01.installation.md +7 -5
  24. msprobe/docs/02.config_introduction.md +14 -1
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  29. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  30. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  31. msprobe/docs/19.monitor.md +2 -0
  32. msprobe/docs/21.visualization_PyTorch.md +15 -80
  33. msprobe/docs/22.visualization_MindSpore.md +20 -104
  34. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  35. msprobe/docs/25.tool_function_introduction.md +1 -0
  36. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  37. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  38. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  40. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  41. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  42. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  43. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  44. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  45. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  46. msprobe/mindspore/cell_processor.py +33 -5
  47. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  48. msprobe/mindspore/compare/utils.py +1 -2
  49. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  50. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  51. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  52. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  53. msprobe/msprobe.py +6 -4
  54. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  55. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  58. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  59. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  60. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  61. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  62. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  63. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  64. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  65. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  66. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  67. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  68. msprobe/pytorch/attl_manager.py +65 -0
  69. msprobe/pytorch/common/utils.py +22 -2
  70. msprobe/pytorch/compare/utils.py +3 -3
  71. msprobe/pytorch/debugger/debugger_config.py +10 -0
  72. msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
  73. msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
  74. msprobe/pytorch/hook_module/api_register.py +6 -1
  75. msprobe/pytorch/monitor/module_hook.py +28 -9
  76. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  77. msprobe/pytorch/pt_config.py +57 -2
  78. msprobe/pytorch/pytorch_service.py +11 -2
  79. msprobe/visualization/builder/graph_builder.py +170 -64
  80. msprobe/visualization/builder/graph_merger.py +0 -1
  81. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  82. msprobe/visualization/db_utils.py +25 -2
  83. msprobe/visualization/graph/base_node.py +0 -24
  84. msprobe/visualization/graph/graph.py +5 -14
  85. msprobe/visualization/graph_service.py +29 -53
  86. msprobe/visualization/utils.py +11 -1
  87. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  88. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  89. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  90. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py
@@ -39,7 +39,6 @@ except ImportError:
 else:
     is_gpu = False
 
-
 torch_without_guard_version = torch.__version__ >= '2.1'
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
@@ -416,7 +415,8 @@ def is_recomputation():
 
     # Identify indices in the call stack where the specific function is being executed
     for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+        if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
+                "megatron" in frame_info.filename):
             backward_function_indices.append(idx)
 
     # Check if the execution is within 'torch/autograd/function.py' file
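For reference, `is_recomputation` walks the Python call stack to decide whether the current hook fires inside a recomputed (checkpointed) backward pass. Below is a minimal standalone sketch of the new condition; the function name and the literal `'backward'` (standing in for `Const.BACKWARD`) are illustrative. Note that `and` binds tighter than `or`, so the `megatron` filename check only constrains the `checkpoint_function_backward` branch:

```python
import inspect

def in_recomputed_backward():
    # Matches any 'backward' frame, plus 'checkpoint_function_backward'
    # frames whose source file path contains "megatron".
    for frame_info in inspect.stack():
        if (frame_info.function == 'backward'
                or (frame_info.function == 'checkpoint_function_backward'
                    and "megatron" in frame_info.filename)):
            return True
    return False
```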
@@ -471,3 +471,23 @@ def register_forward_hook(module, forward_hook):
         module.register_forward_hook(forward_hook, with_kwargs=True)
     else:
         module.register_forward_hook(forward_hook)
+
+
+def save_api_data(api_data):
+    """Save data to io stream"""
+    try:
+        io_buff = io.BytesIO()
+        torch.save(api_data, io_buff)
+    except Exception as e:
+        raise RuntimeError(f"save api_data to io_buff failed") from e
+    return io_buff
+
+
+def load_api_data(api_data_bytes):
+    """Load data from bytes stream"""
+    try:
+        buffer = io.BytesIO(api_data_bytes)
+        buffer = torch.load(buffer, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"load api_data from bytes failed") from e
+    return buffer
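These two helpers round-trip API payloads through an in-memory buffer instead of a file, which is the form the new tensor transport layer ships over TCP or NFS. A minimal usage sketch (the payload contents are illustrative):

```python
import io
import torch

# Serialize into a BytesIO buffer, then restore on the CPU, mirroring
# the save_api_data/load_api_data pair above.
payload = {"args": (torch.randn(2, 3),), "kwargs": {}}
buffer = io.BytesIO()
torch.save(payload, buffer)
restored = torch.load(io.BytesIO(buffer.getvalue()), map_location="cpu")
assert torch.equal(restored["args"][0], payload["args"][0])
```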
msprobe/pytorch/compare/utils.py
@@ -27,15 +27,15 @@ def read_pt_data(dir_path, file_name):
         return None
 
     data_path = os.path.join(dir_path, file_name)
-    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                               FileCheckConst.PT_SUFFIX, False)
+    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
     data_path = path_checker.common_check()
     try:
         # detach because numpy can not process gradient information
         data_value = load_pt(data_path, to_cpu=True).detach()
     except RuntimeError as e:
         # catch the exception raised inside load_pt
-        logger.error(f"Failed to load the .pt file at {data_path}.")
+        data_path_file_name = os.path.basename(data_path)
+        logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
         raise CompareException(CompareException.INVALID_FILE_ERROR) from e
     except AttributeError as e:
         # catch the exception raised by the detach method
msprobe/pytorch/debugger/debugger_config.py
@@ -48,6 +48,16 @@ class DebuggerConfig:
             "max_sample": task_config.max_sample
         }
 
+        self.online_run_ut = False
+        if self.task == Const.TENSOR:
+            # dump api tensor and collaborate with online run_ut
+            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
+            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
+            self.tls_path = task_config.tls_path if task_config.tls_path else ""
+            self.host = task_config.host if task_config.host else ""
+            self.port = task_config.port if task_config.port else -1
+            self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+                if isinstance(task_config.online_run_ut_recompute, bool) else False
 
         self.check()
         self._check_statistics_config(task_config)
msprobe/pytorch/dump/module_dump/hook_wrapper.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from functools import wraps
+from typing import Any, Callable
 
 import torch
 from torch.utils.hooks import BackwardHook
@@ -21,11 +22,17 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 def wrap_setup_backward_hook(func):
-    def requires_clone(tensor):
-        return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
+    def requires_clone(tensor, need_check_leaf=False):
+        need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
+        if need_check_leaf:
+            need_clone &= tensor.grad_fn is not None
+        return need_clone
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def parse_tensor(item, tensor_list):
@@ -39,20 +46,20 @@ def wrap_setup_backward_hook(func):
                 parse_tensor(value, tensor_list)
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
-    def rebuild_args(item, tensor_iter):
-        if requires_clone(item):
+    def rebuild_args(item, tensor_iter, need_check_leaf=False):
+        if requires_clone(item, need_check_leaf):
             result = next(tensor_iter)
             if hasattr(result, "_base") and result._base is not None:
                 if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
                     torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
-                return result
+            return result
         if isinstance(item, list):
             for index, value in enumerate(item):
-                item[index] = rebuild_args(value, tensor_iter)
+                item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, dict):
             for key, value in item.items():
-                item[key] = rebuild_args(value, tensor_iter)
+                item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, tuple):
             if hasattr(item, '_fields'):
@@ -89,3 +96,23 @@ def wrap_setup_backward_hook(func):
 def wrap_setup_input_output_hook():
     BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
     BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
+
+
+def get_apply_func_wrapper(original_func: Callable) -> Callable:
+    @wraps(original_func)
+    def wrapped_apply(*args, **kwargs) -> Any:
+        api_register = get_api_register()
+        if api_register:
+            api_register.restore_inner_used_api()
+        result = original_func(*args, **kwargs)
+        if api_register:
+            api_register.register_inner_used_api()
+        return result
+
+    return wrapped_apply
+
+
+def wrap_backward_hook_function_apply():
+    if torch_version_above_or_equal_2:
+        original_apply = torch.nn.modules._functions.BackwardHookFunction.apply
+        torch.nn.modules._functions.BackwardHookFunction.apply = get_apply_func_wrapper(original_apply)
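The wrapper temporarily restores the inner-used APIs that msprobe has patched (here `Tensor.view_as`, per the `_inner_used_api` table added in `api_register.py` below) while `BackwardHookFunction.apply` runs, then re-applies the patches. A generic sketch of the same suspend/call/resume pattern, with hypothetical `suspend`/`resume` callables standing in for `restore_inner_used_api`/`register_inner_used_api`:

```python
from functools import wraps
from typing import Any, Callable

def suspend_around(original: Callable, suspend: Callable, resume: Callable) -> Callable:
    """Run `suspend` before and `resume` after every call to `original`."""
    @wraps(original)
    def wrapped(*args, **kwargs) -> Any:
        suspend()                          # undo the API patching
        result = original(*args, **kwargs)
        resume()                           # re-apply the API patching
        return result
    return wrapped
```

Like the diff above, the sketch does not use try/finally, so the patching stays suspended if the wrapped call raises.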
msprobe/pytorch/dump/module_dump/module_processer.py
@@ -13,23 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
-if torch_version_above_or_equal_2:
+torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
+if torch_version_above_or_equal_21:
     from torch.utils.checkpoint import _StopRecomputationError
 
 
@@ -61,7 +67,8 @@ def wrap_forward_with_hook_safety(module):
                 hook_fn = list(module._forward_hooks.values())[0]
                 hook_fn(module, args, kwargs, exception_output)
             raise e
-    if torch_version_above_or_equal_2:
+
+    if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
 
 
@@ -78,10 +85,13 @@ class ModuleProcesser:
    def __init__(self, scope):
        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
        wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
        try:
            from megatron.core.pipeline_parallel import schedules
            origin_func_id = id(schedules.deallocate_output_tensor)
            schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            schedules.forward_step = wrap_megatron_step(schedules.forward_step)
+            schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
            for module in list(sys.modules.values()):
                if module.__name__ == 'schedules':
                    continue
@@ -258,14 +268,16 @@
             ModuleProcesser.module_stack[tid] = []
 
         if self.module_stack[tid]:
-            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
+            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
+                else [self.module_stack[tid][-1], get_micro_step()]
         else:
             parent_name = ModuleProcesser.module_queue.find_last(full_name)
-            ModuleProcesser.module_node[full_name] = parent_name
+            ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
+                else [parent_name, get_micro_step()]
 
         ModuleProcesser.module_queue.add_name(full_name)
         ModuleProcesser.module_stack[tid].append(full_name)
-        ModuleProcesser.api_parent_node[tid] = full_name
+        ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
         if self.scope:
             self.scope.begin_module(full_name)
 
@@ -273,14 +285,15 @@
         tid = threading.get_ident()
         if torch_version_above_or_equal_2 or is_forward:
             ModuleProcesser.module_queue.remove_name(full_name)
-            ModuleProcesser.api_parent_node[tid] = None
+            ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
             if self.module_stack.get(tid):
                 ModuleProcesser.module_stack[tid].pop()
             if self.module_stack.get(tid):
-                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
+                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
+                    else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
             if self.scope:
                 self.scope.end_module(full_name)
         else:
             if self.scope:
                 self.scope.begin_module(full_name)
-            ModuleProcesser.api_parent_node[tid] = full_name
+            ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
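The `[name, micro_step]` pairs let later tooling distinguish module calls that repeat across pipeline micro-batches. The new `msprobe/core/common/megatron_utils.py` (added in this release but not shown in these hunks) supplies `is_megatron`, `get_micro_step`, and `wrap_megatron_step`; below is a speculative sketch of how such a counter could be kept, with all internals hypothetical:

```python
from functools import wraps

_megatron_active = False  # hypothetical module state
_micro_step = 0

def is_megatron():
    return _megatron_active

def get_micro_step():
    return _micro_step

def wrap_megatron_step(step_func, is_forward=True):
    # Wrap megatron's schedules.forward_step/backward_step so each call
    # marks megatron as active and advances a micro-step counter.
    @wraps(step_func)
    def wrapped(*args, **kwargs):
        global _megatron_active, _micro_step
        _megatron_active = True
        result = step_func(*args, **kwargs)
        if not is_forward:
            _micro_step += 1  # one backward step closes a micro-batch
        return result
    return wrapped
```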
msprobe/pytorch/hook_module/api_register.py
@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
msprobe/pytorch/monitor/module_hook.py
@@ -19,12 +19,14 @@ import importlib
 from collections import defaultdict
 from datetime import datetime
 from functools import partial
+from itertools import cycle
 
 import pytz
 import torch
 import torch.distributed as dist
 import pandas as pd
 from torch.utils.hooks import BackwardHook
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json, make_dir
@@ -229,6 +231,8 @@ class TrainerMon:
         self.duplicate_param = {}
         self.name2tag = {}
         self.param_name_call_id = {}
+        self.flat_prefix_names = []
+        self.flat_prefix_reverse_iter = None
         self.call_id = 0
         self.module_struct = defaultdict(dict)
         self.grad_accs = []
@@ -945,13 +949,20 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
+        if isinstance(model_chunk, FSDP):
+            if not model_chunk._use_orig_params:
+                raise ValueError("Only Support fsdp1 with use_orig_params=True")
+            self.fsdp_wrapped_module = True
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
-            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
-                self.fsdp_wrapped_module = True
             if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
                 self.fsdp2_wrapped_module = True
+            if self.fsdp_wrapped_module:  # FSDP1: record the full flat-weight prefix names, unrestricted by target, for unpacking the flat param later
+                flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
+                if flat_prefix_name not in self.flat_prefix_names:
+                    self.flat_prefix_names.append(flat_prefix_name)
+
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
@@ -975,6 +986,8 @@ class TrainerMon:
                 k: get_summary_writer_tag_name(name, k, self.rank)
                 for k in keywords
             }
+        if self.fsdp_wrapped_module:
+            self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names))  # post_backward_hook is invoked in reverse order
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -1224,17 +1237,22 @@
         In each forward phase, FSDP re-registers its hooks on AccumulateGrad, so hooks registered by the
         monitor tool cannot take effect; _post_backward_hook is therefore patched to collect gradients after backward and before reduce_scatter.
         """
+
         def patch_post_backward_hook(_post_backward_hook):
             def wrapper(state, handle, *unused):
                 grad_dict = {}
-                offset = 0
-                for param, name in self.param2name.items():
-                    limit = param.numel()
-                    if not limit:
+                local_names = handle.flat_param._fqns
+                offsets = handle._get_flat_param_offsets()
+                shapes = handle.flat_param._shapes
+                flat_prefix = next(self.flat_prefix_reverse_iter)
+                for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
+                    grad_clip = handle.flat_param.grad[start:end + 1]
+                    grad = grad_clip.reshape(local_shape)
+                    total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
+                    if total_name not in self.origin2squash:
+                        logger.warning(f"{total_name} not in model.named_parameters(), skip.")
                         continue
-                    grad = handle.flat_param.grad[offset:offset + limit]
-                    offset += limit
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
@@ -1242,6 +1260,7 @@
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = _post_backward_hook(state, handle, *unused)
                 return out
+
             return wrapper
 
         logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
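The patched hook now unflattens the FSDP1 flat gradient from the handle's recorded offsets and shapes instead of assuming `param2name` iteration order matches the flat layout. A toy sketch of the slicing arithmetic; offsets are inclusive `(start, end)` pairs, hence the `end + 1`:

```python
import torch

# A flat gradient holding a 2x3 weight followed by a 6-element bias.
flat_grad = torch.arange(12.0)
offsets = [(0, 5), (6, 11)]                     # inclusive ranges within the flat buffer
shapes = [torch.Size([2, 3]), torch.Size([6])]
grads = [flat_grad[start:end + 1].reshape(shape)
         for (start, end), shape in zip(offsets, shapes)]
assert grads[0].shape == (2, 3) and grads[1].shape == (6,)
```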
msprobe/pytorch/online_dispatch/dispatch.py
@@ -17,7 +17,7 @@ import json
 import os
 import time
 import multiprocessing
-from multiprocessing import Pool
+from multiprocessing import Pool, Lock
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
 from msprobe.pytorch.online_dispatch.compare import Comparator
 from msprobe.core.common.utils import check_str_param, safe_get_value
 
+child_global_lock = None
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
 DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
         self.get_ops(yaml_path)
 
-        self.lock = None
+        self.lock = Lock() if process_num > 0 else None
         max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
         if process_num > max_process_num:
             logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
             raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
                                     f'but got {process_num}!')
         if process_num > 0:
-            self.pool = Pool(process_num)
+            self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
         if debug:
             logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error("Please check train log, An exception may have occurred!")
             return
         check_file_or_directory_path(summary_path, False)
-        fp_handle = FileOpen(summary_path, "r")
-        while True:
-            json_line_data = fp_handle.readline()
-            if json_line_data == '\n':
-                continue
-            if len(json_line_data) == 0:
-                break
-            msg = json.loads(json_line_data)
-            if len(msg) < 2:
-                raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
-            self.all_summary[msg[0]] = msg[1]
-        fp_handle.close()
+        with FileOpen(summary_path, "r") as fp_handle:
+            while True:
+                json_line_data = fp_handle.readline()
+                if json_line_data == '\n':
+                    continue
+                if len(json_line_data) == 0:
+                    break
+                msg = json.loads(json_line_data)
+                if len(msg) < 2:
+                    raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
+                self.all_summary[msg[0]] = msg[1]
 
         if self.debug_flag:
             input_num = 0
@@ -163,11 +163,16 @@
 
         call_stack = get_callstack()
         self.call_stack_list.append(call_stack)
-        self.api_index += 1
-        if aten_api not in self.single_api_index_dict:
-            self.single_api_index_dict[aten_api] = 1
-        else:
-            self.single_api_index_dict[aten_api] += 1
+
+        self.lock.acquire() if self.process_num > 0 else None
+        try:
+            self.api_index += 1
+            if aten_api not in self.single_api_index_dict:
+                self.single_api_index_dict[aten_api] = 1
+            else:
+                self.single_api_index_dict[aten_api] += 1
+        finally:
+            self.lock.release() if self.process_num > 0 else None
 
         run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
 
@@ -180,7 +185,7 @@
         cpu_kwargs = []
         data_to_cpu(args, 0, cpu_args)
         data_to_cpu(kwargs, 0, cpu_kwargs)
-
+
         cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
         cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
 
@@ -194,7 +199,12 @@
         try:
             cpu_out = func(*cpu_args, **cpu_kwargs)
         except RuntimeError as e:
-            self.api_index -= 1
+            self.lock.acquire() if self.process_num > 0 else None
+            try:
+                self.api_index -= 1
+                self.single_api_index_dict[aten_api] -= 1
+            finally:
+                self.lock.release() if self.process_num > 0 else None
             logger.warning(f"RuntimeError: {e}")
             logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
             return npu_out
@@ -215,7 +225,7 @@
             run_param.process_flag = True
             if self.check_fun(func, run_param):
                 data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
-                                             self.lock)
+                                             child_global_lock)
                 self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                       error_callback=error_call)
             else:
@@ -233,12 +243,20 @@
                 return True
         return False
 
+    @staticmethod
+    def _init_child_process(lock):
+        global child_global_lock
+        child_global_lock = lock
+
     def get_dir_name(self, tag):
         # guarantee file uniqueness
         time.sleep(1)
-        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        # Time format: year-month-day-hour-minute-second plus milliseconds (thousandths of a second)
+        time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3]  # keep the first 3 digits as milliseconds
+
         if tag is None or not isinstance(tag, str):
             logger.warning('There is not tag or the type of tag is not string.')
+            # Directory name format: msprobe_rank{device id}_{millisecond timestamp}
             dir_name = f'msprobe_rank{self.device_id}_{time_now}'
         else:
             dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
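Because pool workers cannot see the parent's `self.lock` attribute, the lock is handed to each child through the `Pool` initializer and stashed in the module-level `child_global_lock`. A self-contained sketch of the pattern:

```python
from multiprocessing import Pool, Lock

child_lock = None  # set once per worker by the initializer

def init_child(lock):
    # Runs in every worker at startup; stash the inherited lock globally.
    global child_lock
    child_lock = lock

def task(i):
    with child_lock:  # serialize the critical section across workers
        return i * i

if __name__ == "__main__":
    lock = Lock()
    with Pool(2, initializer=init_child, initargs=(lock,)) as pool:
        print(pool.map(task, range(4)))
```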
msprobe/pytorch/pt_config.py
@@ -35,15 +35,48 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
+        self.online_run_ut = json_config.get("online_run_ut", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.tls_path = json_config.get("tls_path", "./")
+        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-
+        if self.online_run_ut:
+            self._check_online_run_ut()
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
+    def _check_online_run_ut(self):
+        if not isinstance(self.online_run_ut, bool):
+            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
+
+        if not isinstance(self.online_run_ut_recompute, bool):
+            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
+
+        if self.nfs_path:
+            check_file_or_directory_path(self.nfs_path, isdir=True)
+            return
+
+        if self.tls_path:
+            check_file_or_directory_path(self.tls_path, isdir=True)
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
+            crl_path = os.path.join(self.tls_path, "crl.pem")
+            if os.path.exists(crl_path):
+                check_file_or_directory_path(crl_path)
+
+        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
+            raise Exception(f"host: {self.host} is invalid.")
+
+        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
+            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
+
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
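Together with the new `DebuggerConfig` fields shown earlier, the `tensor` task can now carry online run_ut connection settings. An illustrative options dict for `TensorConfig`; keys and defaults come from the diff above, the concrete values are examples only:

```python
# Either nfs_path (shared filesystem) or host/port plus tls_path (TLS over
# TCP) is expected; nfs_path takes precedence in _check_online_run_ut.
tensor_task_options = {
    "online_run_ut": True,             # stream dumped API data to an online checker
    "nfs_path": "",                    # non-empty: exchange data via shared storage
    "host": "127.0.0.1",               # validated against Const.ipv4_pattern
    "port": 8080,                      # must satisfy 0 < port <= 65535
    "tls_path": "./certs",             # must contain client.key, client.crt, ca.crt (crl.pem optional)
    "online_run_ut_recompute": False,  # presumably also checks APIs captured during recomputation
}
```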
@@ -80,6 +113,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +180,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
@@ -218,7 +257,12 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-
+        self.is_online = json_config.get("is_online", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_run_ut_config()
 
     @classmethod
@@ -236,11 +280,22 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
+    @classmethod
+    def check_nfs_path_config(cls, nfs_path):
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
+
+    @classmethod
+    def check_tls_path_config(cls, tls_path):
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
+        RunUTConfig.check_nfs_path_config(self.nfs_path)
+        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
msprobe/pytorch/pytorch_service.py
@@ -15,8 +15,9 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
+from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -24,6 +25,9 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
+if torch_version_above_or_equal_2:
+    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
+
 
 class PytorchService(BaseService):
     @property
@@ -41,10 +45,12 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
+        self.attl_manager = ATTLManager(self.config)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
+        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -59,6 +65,9 @@
         self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
+    def _run_ut_dispatch(self, status):
+        if torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
 
     def _reset_status(self):
         super()._reset_status()