mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
  3. msprobe/core/common/const.py +3 -0
  4. msprobe/core/common/file_utils.py +45 -5
  5. msprobe/core/common/utils.py +117 -13
  6. msprobe/core/common_config.py +15 -1
  7. msprobe/core/compare/acc_compare.py +21 -9
  8. msprobe/core/compare/compare_cli.py +10 -2
  9. msprobe/core/compare/merge_result/merge_result.py +1 -1
  10. msprobe/core/compare/utils.py +8 -2
  11. msprobe/core/config_check/checkers/base_checker.py +2 -0
  12. msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
  13. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
  14. msprobe/core/config_check/config_check_cli.py +1 -1
  15. msprobe/core/config_check/config_checker.py +1 -2
  16. msprobe/core/data_dump/data_collector.py +4 -1
  17. msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
  19. msprobe/core/debugger/precision_debugger.py +13 -8
  20. msprobe/core/hook_manager.py +112 -82
  21. msprobe/core/monitor/utils.py +338 -0
  22. msprobe/core/service.py +2 -1
  23. msprobe/core/single_save/single_comparator.py +5 -3
  24. msprobe/docs/01.installation.md +1 -0
  25. msprobe/docs/05.data_dump_PyTorch.md +4 -4
  26. msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
  27. msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
  29. msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
  30. msprobe/docs/12.overflow_check_PyTorch.md +3 -2
  31. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  32. msprobe/docs/14.data_parse_PyTorch.md +35 -32
  33. msprobe/docs/21.visualization_PyTorch.md +9 -8
  34. msprobe/docs/22.visualization_MindSpore.md +1 -0
  35. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  36. msprobe/docs/24.code_mapping_Mindspore.md +6 -5
  37. msprobe/docs/31.config_check.md +15 -5
  38. msprobe/docs/33.generate_operator_MindSpore.md +2 -2
  39. msprobe/docs/34.RL_collect.md +18 -9
  40. msprobe/docs/35.nan_analyze.md +4 -3
  41. msprobe/docs/FAQ.md +3 -0
  42. msprobe/docs/img/ms_layer.png +0 -0
  43. msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
  44. msprobe/mindspore/cell_processor.py +35 -14
  45. msprobe/mindspore/code_mapping/bind.py +23 -4
  46. msprobe/mindspore/code_mapping/graph_parser.py +6 -4
  47. msprobe/mindspore/common/utils.py +3 -0
  48. msprobe/mindspore/compare/common_dir_compare.py +32 -12
  49. msprobe/mindspore/compare/ms_graph_compare.py +7 -2
  50. msprobe/mindspore/compare/utils.py +9 -1
  51. msprobe/mindspore/debugger/debugger_config.py +13 -11
  52. msprobe/mindspore/debugger/precision_debugger.py +67 -45
  53. msprobe/mindspore/dump/dump_tool_factory.py +2 -0
  54. msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
  55. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
  56. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
  57. msprobe/mindspore/dump/jit_dump.py +6 -3
  58. msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
  59. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
  60. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
  61. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
  62. msprobe/mindspore/mindspore_service.py +2 -2
  63. msprobe/mindspore/monitor/common_func.py +1 -1
  64. msprobe/mindspore/monitor/module_hook.py +3 -3
  65. msprobe/mindspore/monitor/utils.py +0 -252
  66. msprobe/mindspore/ms_config.py +0 -1
  67. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  68. msprobe/nan_analyze/graph.py +4 -0
  69. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
  70. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
  71. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
  72. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
  73. msprobe/pytorch/common/utils.py +0 -16
  74. msprobe/pytorch/compare/pt_compare.py +5 -0
  75. msprobe/pytorch/debugger/debugger_config.py +12 -5
  76. msprobe/pytorch/debugger/precision_debugger.py +8 -1
  77. msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
  78. msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
  79. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
  80. msprobe/pytorch/hook_module/hook_module.py +9 -9
  81. msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
  82. msprobe/pytorch/monitor/csv2tb.py +3 -10
  83. msprobe/pytorch/monitor/features.py +5 -0
  84. msprobe/pytorch/monitor/module_hook.py +6 -7
  85. msprobe/pytorch/monitor/module_metric.py +0 -3
  86. msprobe/pytorch/monitor/optimizer_collect.py +1 -1
  87. msprobe/pytorch/monitor/utils.py +1 -317
  88. msprobe/pytorch/online_dispatch/dispatch.py +1 -1
  89. msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
  90. msprobe/pytorch/parse_tool/lib/utils.py +2 -4
  91. msprobe/visualization/graph_service.py +1 -1
  92. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
  93. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
  94. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
  95. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
@@ -23,10 +23,11 @@
23
23
  msprobe -f pytorch nan_analyze -i dump_step_path -o output_dir_path
24
24
  ```
25
25
 
26
- | 参数 | 说明 |
27
- |--------------------||
26
+ | 参数 | 说明 |
27
+ |--------------------|---------------------------------------------|
28
+ | -f 或 --framework | 指定训练框架,取值为 pytorch。必选。 |
28
29
  | -i 或 --input_path | dump数据的目录。需指定到step层级,如`-i /xxx/dump/step0/` |
29
- | -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 \"./output/" 目录。 |
30
+ | -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 "./output/" 目录。 |
30
31
 
31
32
  ### 输出文件介绍
32
33
 
msprobe/docs/FAQ.md CHANGED
@@ -36,6 +36,9 @@
36
36
  该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。
37
37
  - 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。
38
38
 
39
+
40
+ 5. 使用 msprobe 对 PyTorch 框架进行数据采集时,请确认环境变量 NPU_ASD_ENABLE=0,即关闭特征值检测功能。由于工具冲突,该功能开启时可能导致部分 API 数据采集缺失。
41
+
39
42
  # 2 精度预检(PyTorch)
40
43
 
41
44
  1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)?
Binary file
@@ -13,6 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import (
17
+ Any,
18
+ Dict,
19
+ List,
20
+ Tuple,
21
+ Union
22
+ )
16
23
  import os
17
24
  import numpy as np
18
25
  import mindspore
@@ -39,6 +46,22 @@ if torch_mindtorch_importer.is_valid_pt_mt_env:
39
46
  else:
40
47
  import torch
41
48
 
49
+ # 为了可读性,我们先给每种返回形态起个别名
50
+ ForwardResult = Tuple[
51
+ List[ComputeElement],
52
+ Tuple[Any, ...],
53
+ Dict[str, Any],
54
+ Tuple[Any, ...],
55
+ ]
56
+
57
+ BackwardResultMT = Tuple[
58
+ List[ComputeElement],
59
+ Union[Any, Tuple[Any, ...]],
60
+ Tuple[Any, ...],
61
+ ]
62
+
63
+ PyTorchBackward = List[ComputeElement]
64
+
42
65
 
43
66
  class ApiInputAggregation:
44
67
  def __init__(self, inputs, kwargs, gradient_inputs) -> None:
@@ -179,7 +202,12 @@ class ApiRunner:
179
202
  return api_instance
180
203
 
181
204
  @staticmethod
182
- def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
205
+ def run_api(
206
+ api_instance,
207
+ api_input_aggregation,
208
+ forward_or_backward: str,
209
+ api_platform: str,
210
+ ) -> Union[ForwardResult, BackwardResultMT, PyTorchBackward]:
183
211
  inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
184
212
  for compute_element in api_input_aggregation.inputs)
185
213
  kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import threading
16
17
  from collections import OrderedDict
17
18
 
18
19
  from mindspore import Tensor
@@ -21,6 +22,8 @@ from mindspore.ops.operations import _inner_ops as inner
21
22
 
22
23
  from msprobe.core.common.const import Const
23
24
  from msprobe.core.common.exceptions import MsprobeException
25
+ from msprobe.core.common.runtime import Runtime
26
+ from msprobe.core.common.utils import ModuleQueue, ThreadSafe
24
27
  from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
25
28
  from msprobe.mindspore.common.const import Const as MsConst
26
29
  from msprobe.mindspore.common.log import logger
@@ -32,7 +35,6 @@ from msprobe.mindspore.common.utils import (
32
35
  )
33
36
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
34
37
  from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
35
- from msprobe.core.common.runtime import Runtime
36
38
 
37
39
 
38
40
  def get_cell_construct(construct):
@@ -40,13 +42,15 @@ def get_cell_construct(construct):
40
42
  if hasattr(self, 'msprobe_hook'):
41
43
  setattr(self, 'msprobe_input_kwargs', kwargs)
42
44
  return construct(self, *args, **kwargs)
45
+
43
46
  return _construct
44
47
 
45
48
 
46
49
  class CellProcessor:
50
+ cell_queue = ModuleQueue()
47
51
  cell_count = {}
48
- cell_stack = []
49
- api_parent_node = None
52
+ cell_stack = {}
53
+ api_parent_node = {}
50
54
  module_node = {}
51
55
  cell_bw_hook_kernels = {}
52
56
  cell_backward_pre_hook = []
@@ -65,9 +69,10 @@ class CellProcessor:
65
69
 
66
70
  @classmethod
67
71
  def reset_cell_stats(cls):
72
+ cls.cell_queue = ModuleQueue()
68
73
  cls.cell_count = {}
69
- cls.cell_stack = []
70
- cls.api_parent_node = None
74
+ cls.cell_stack = {}
75
+ cls.api_parent_node = {}
71
76
  cls.module_node = {}
72
77
  cls.cell_bw_hook_kernels = {}
73
78
  cls.cell_backward_pre_hook = []
@@ -122,6 +127,7 @@ class CellProcessor:
122
127
  GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
123
128
 
124
129
  def build_cell_hook(self, cell_name, build_data_hook):
130
+ @ThreadSafe.synchronized
125
131
  def forward_pre_hook(cell, args):
126
132
  index = CellProcessor.set_and_get_calls_number(cell_name)
127
133
  full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
@@ -146,11 +152,13 @@ class CellProcessor:
146
152
  setattr(cell, 'msprobe_forward_hook', True)
147
153
 
148
154
  def get_backward_hook(backward_data_hook, full_backward_name):
155
+ @ThreadSafe.synchronized
149
156
  def backward_hook_fn(cell, grad_input, grad_output):
150
157
  new_output = backward_data_hook(cell, grad_input, grad_output)
151
158
  self.set_construct_info_in_hook(full_backward_name)
152
159
  cell.has_pre_hook_called = False
153
160
  return new_output
161
+
154
162
  return backward_hook_fn
155
163
 
156
164
  enable_hooked = sum(
@@ -170,13 +178,14 @@ class CellProcessor:
170
178
 
171
179
  return args
172
180
 
181
+ @ThreadSafe.synchronized
173
182
  def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
174
183
  index = CellProcessor.cell_count.get(cell_name, 0)
175
184
  full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
176
185
  full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
177
186
 
178
187
  self.set_construct_info_in_hook(full_forward_name)
179
-
188
+
180
189
  hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
181
190
  hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
182
191
  if hook_result is not None:
@@ -199,6 +208,7 @@ class CellProcessor:
199
208
  outputs = new_outputs
200
209
 
201
210
  def get_backward_pre_hook(full_backward_name, backward_data_hook):
211
+ @ThreadSafe.synchronized
202
212
  def backward_pre_hook_fn(cell, grad_output):
203
213
  cell.has_pre_hook_called = True
204
214
  self.set_construct_info_in_pre_hook(full_backward_name)
@@ -206,6 +216,7 @@ class CellProcessor:
206
216
  backward_data_hook(cell, (), grad_output)
207
217
  self.set_construct_info_in_hook(full_backward_name)
208
218
  cell.has_pre_hook_called = False
219
+
209
220
  return backward_pre_hook_fn
210
221
 
211
222
  backward_pre_hook = OrderedDict()
@@ -233,18 +244,28 @@ class CellProcessor:
233
244
  return forward_pre_hook
234
245
 
235
246
  def set_construct_info_in_pre_hook(self, full_name):
236
- if self.cell_stack:
237
- CellProcessor.module_node[full_name] = self.cell_stack[-1]
247
+ tid = threading.get_ident()
248
+ if tid not in self.cell_stack:
249
+ CellProcessor.cell_stack[tid] = []
250
+
251
+ if self.cell_stack[tid]:
252
+ CellProcessor.module_node[full_name] = self.cell_stack[tid][-1]
238
253
  else:
239
- CellProcessor.module_node[full_name] = None
240
- CellProcessor.cell_stack.append(full_name)
241
- CellProcessor.api_parent_node = full_name
254
+ parent_name = CellProcessor.cell_queue.find_last(full_name)
255
+ CellProcessor.module_node[full_name] = parent_name
256
+
257
+ CellProcessor.cell_queue.add_name(full_name)
258
+ CellProcessor.cell_stack[tid].append(full_name)
259
+ CellProcessor.api_parent_node[tid] = full_name
242
260
  if self.scope:
243
261
  self.scope.begin_module(full_name)
244
262
 
245
263
  def set_construct_info_in_hook(self, full_name):
246
- if self.cell_stack:
247
- CellProcessor.cell_stack.pop()
248
- CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None
264
+ tid = threading.get_ident()
265
+ CellProcessor.api_parent_node[tid] = None
266
+ if self.cell_stack.get(tid):
267
+ CellProcessor.cell_stack[tid].pop()
268
+ if self.cell_stack.get(tid):
269
+ CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1]
249
270
  if self.scope:
250
271
  self.scope.end_module(full_name)
@@ -119,9 +119,17 @@ def find_npy_files(npy_path):
119
119
 
120
120
  # 如果是目录,使用Path.rglob查找所有.npy文件
121
121
  if npy_path_obj.is_dir():
122
- for file in npy_path_obj.rglob(Const.NUMPY_PATTERN):
123
- check_file_or_directory_path(file)
124
- npy_files.append(file.resolve())
122
+ base_depth = len(npy_path_obj.resolve().parts)
123
+ for root, dirs, files in os.walk(npy_path_obj):
124
+ current_depth = len(Path(root).resolve().parts) - base_depth
125
+ if current_depth >= 10:
126
+ dirs[:] = []
127
+
128
+ for filename in files:
129
+ if filename.endswith(Const.NUMPY_SUFFIX):
130
+ file_path = Path(root) / filename
131
+ check_file_or_directory_path(file_path)
132
+ npy_files.append(file_path.resolve())
125
133
  else:
126
134
  logger.info(f"The specified path is neither an .npy file nor a directory: {npy_path}")
127
135
 
@@ -254,7 +262,18 @@ def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict
254
262
  corresponding_name = None
255
263
  name_without_ext = os.path.splitext(corresponding_name)[0]
256
264
  npy_path = os.path.realpath(npy_file)
257
- node_scope = name_without_ext.split(".")[1]
265
+
266
+ parts = name_without_ext.split(".")
267
+ if len(parts) < 2:
268
+ logger.error(
269
+ f'File name "{file_name}" in "{directory}" '
270
+ f'does not conform to expected format (missing scope separator ".")!'
271
+ )
272
+ raise Exception(
273
+ f'File name "{file_name}" has incorrect format, cannot extract node scope!'
274
+ )
275
+ node_scope = parts[1]
276
+
258
277
  trie = Trie()
259
278
  for key, value in match_dict.items():
260
279
  trie.insert(key, value)
@@ -77,7 +77,7 @@ class Parser:
77
77
 
78
78
  @staticmethod
79
79
  def extract_constants(inputs_str: str) -> List[str]:
80
- constant_pattern = re.compile(r'\b(\w+\(.*?\))')
80
+ constant_pattern = re.compile(r'\b([A-Za-z_][A-Za-z0-9_]{0,10000})\(([A-Za-z0-9_\s,.\-+/]{0,10000})\)')
81
81
  constants = constant_pattern.findall(inputs_str)
82
82
  return constants
83
83
 
@@ -90,7 +90,8 @@ class Parser:
90
90
  self.nodes[func_name] = func_graph_info
91
91
 
92
92
  def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None:
93
- node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(')
93
+ node_pattern = re.compile(
94
+ r'(%\d{1,10000})\(([A-Za-z0-9_\.]{1,10000})\)\s*=\s*([A-Za-z_][A-Za-z0-9_]{0,10000})\(')
94
95
  matches = list(node_pattern.finditer(text))
95
96
  for i, match in enumerate(matches):
96
97
  series_number = match.group(1)
@@ -106,8 +107,9 @@ class Parser:
106
107
 
107
108
  constants = self.__class__.extract_constants(args_str)
108
109
 
109
- scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE)
110
-
110
+ scope_pattern = re.compile(
111
+ r'^(?=.{0,300}$)[ \t]*\#[ \t]*[^\r\n]*?scope[^\r\n]*?:[ \t]*\(([^)\r\n]{1,200})\)[ \t]*$',
112
+ re.IGNORECASE | re.MULTILINE)
111
113
  scope_match = scope_pattern.search(text, end_pos)
112
114
  scope = scope_match.group(1) if scope_match else ""
113
115
 
@@ -95,6 +95,9 @@ def save_tensor_as_npy(tensor, file_path):
95
95
 
96
96
 
97
97
  def convert_to_int(value):
98
+ if isinstance(value, bool):
99
+ logger.error('The value in rank_id or step should be int, please check!')
100
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
98
101
  try:
99
102
  return int(value)
100
103
  except Exception:
@@ -30,9 +30,10 @@ from msprobe.core.common.utils import CompareException
30
30
  from msprobe.core.common.exceptions import FileCheckException
31
31
  from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
32
32
  check_path_before_create, load_npy
33
- from msprobe.core.common.const import CompareConst, FileCheckConst
33
+ from msprobe.core.common.const import CompareConst
34
34
  from msprobe.core.compare.npy_compare import compare_ops_apply
35
35
  from msprobe.core.compare.multiprocessing_compute import check_accuracy
36
+ from msprobe.mindspore.compare.utils import check_name_map_dict
36
37
 
37
38
 
38
39
  def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
@@ -49,6 +50,7 @@ def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataF
49
50
  npu_root = Path(input_params.get('npu_path'))
50
51
  bench_root = Path(input_params.get('bench_path'))
51
52
  name_map_dict = input_params.get('map_dict', {})
53
+ check_name_map_dict(name_map_dict)
52
54
  file_tree = build_mirror_file_tree(npu_root, bench_root)
53
55
 
54
56
  # 处理文件比对
@@ -114,24 +116,42 @@ def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple
114
116
  file_tree = {}
115
117
 
116
118
  # 遍历NPU目录构建树结构
117
- for npu_path in npu_root.rglob('*.npy'):
118
- dir_path = npu_path.relative_to(npu_root).parent
119
- npu_dir_pair = os.path.join(npu_root, dir_path)
120
- bench_dir_pair = os.path.join(bench_root, dir_path)
121
- try:
122
- check_file_or_directory_path(bench_dir_pair, isdir=True)
123
- except FileCheckException:
119
+ # 使用os.walk遍历目录,限制深度为10层
120
+ for root, dirs, files in os.walk(npu_root):
121
+ # 计算当前目录深度
122
+ depth = len(Path(root).relative_to(npu_root).parts)
123
+ if depth > 10:
124
+ dirs.clear() # 清空dirs列表以阻止继续递归
124
125
  continue
125
- # 添加到文件树
126
- if dir_path not in file_tree:
127
- file_tree[dir_path] = (npu_dir_pair, bench_dir_pair)
126
+
127
+ # 检查当前目录下是否有npy文件
128
+ if any(f.endswith('.npy') for f in files):
129
+ # 获取相对路径
130
+ dir_path = Path(root).relative_to(npu_root)
131
+ npu_dir_pair = os.path.join(npu_root, dir_path)
132
+ bench_dir_pair = os.path.join(bench_root, dir_path)
133
+
134
+ try:
135
+ check_file_or_directory_path(bench_dir_pair, isdir=True)
136
+ except FileCheckException:
137
+ continue
138
+
139
+ # 添加到文件树
140
+ if dir_path not in file_tree:
141
+ file_tree[dir_path] = (npu_dir_pair, bench_dir_pair)
128
142
 
129
143
  return file_tree
130
144
 
131
145
 
132
146
  def find_npy_files(directory):
133
147
  npy_files_dict = {}
134
- for root, _, files in os.walk(directory):
148
+ # 限制递归深度:目录深度超过10层后不再继续向下遍历
149
+ for root, dirs, files in os.walk(directory, topdown=True):
150
+ # 计算当前目录深度
151
+ depth = root[len(directory):].count(os.sep)
152
+ # 如果深度超过10层则跳过
153
+ if depth > 10:
154
+ dirs.clear()
135
155
  for file in files:
136
156
  if file.endswith(".npy"):
137
157
  # 分割文件名并去掉最后两个元素
@@ -168,8 +168,13 @@ class GraphMSComparator:
168
168
  self.output_path = output_path
169
169
  self.base_npu_path = input_param.get('npu_path', None)
170
170
  self.base_bench_path = input_param.get('bench_path', None)
171
- self.rank_list = [convert_to_int(rank_id) for rank_id in input_param.get('rank_id', [])]
172
- self.step_list = [convert_to_int(step_id) for step_id in input_param.get('step_id', [])]
171
+ rank_id_list = input_param.get('rank_id', [])
172
+ step_id_list = input_param.get('step_id', [])
173
+ if not isinstance(rank_id_list, list) or not isinstance(step_id_list, list):
174
+ logger.error("'rank_id' and 'step_id' should both be lists, please check!")
175
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
176
+ self.rank_list = [convert_to_int(rank_id) for rank_id in rank_id_list]
177
+ self.step_list = [convert_to_int(step_id) for step_id in step_id_list]
173
178
  # split by rank and step, generate rank step path
174
179
  self.npu_rank_step_dict = self.generate_rank_step_path(self.base_npu_path)
175
180
  self.bench_rank_step_dict = self.generate_rank_step_path(self.base_bench_path)
@@ -17,7 +17,8 @@ import os
17
17
 
18
18
  from msprobe.core.common.const import Const
19
19
  from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst
20
- from msprobe.core.common.utils import detect_framework_by_dump_json
20
+ from msprobe.core.common.utils import detect_framework_by_dump_json, CompareException, check_op_str_pattern_valid
21
+ from msprobe.core.common.log import logger
21
22
 
22
23
 
23
24
  def read_npy_data(dir_path, file_name):
@@ -35,3 +36,10 @@ def read_npy_data(dir_path, file_name):
35
36
  def check_cross_framework(bench_json_path):
36
37
  framework = detect_framework_by_dump_json(bench_json_path)
37
38
  return framework == Const.PT_FRAMEWORK
39
+
40
+
41
+ def check_name_map_dict(name_map_dict):
42
+ if not isinstance(name_map_dict, dict):
43
+ logger.error("'map_dict' should be a dict, please check!")
44
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
45
+ check_op_str_pattern_valid(str(name_map_dict))
@@ -81,18 +81,22 @@ class DebuggerConfig:
81
81
  target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
82
82
  if models is None or isinstance(models, target_module_type[0]):
83
83
  return models
84
- error_model = None
85
84
  if isinstance(models, (list, tuple)):
85
+ error_model = None
86
86
  for model in models:
87
87
  if not isinstance(model, target_module_type[0]):
88
88
  error_model = model
89
89
  break
90
- else:
91
- error_model = models
90
+ if error_model is not None:
91
+ error_info = (
92
+ f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
93
+ f"type, currently there is a {type(error_model)} type.")
94
+ raise MsprobeException(
95
+ MsprobeException.INVALID_PARAM_ERROR, error_info)
92
96
 
93
- if error_model is not None:
97
+ else:
94
98
  error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
95
- f"type, currently there is a {type(error_model)} type.")
99
+ f"type, currently there is a {type(models)} type.")
96
100
  raise MsprobeException(
97
101
  MsprobeException.INVALID_PARAM_ERROR, error_info)
98
102
  return models
@@ -125,16 +129,14 @@ class DebuggerConfig:
125
129
  self.level_ori = Const.LEVEL_MIX
126
130
  return True
127
131
 
128
- def check_config_with_l2(self):
129
- if self.level_ori != Const.LEVEL_L2:
130
- return
131
- if self.task != Const.TENSOR:
132
+ def check_config_with_l2(self, is_graph_config):
133
+ if not is_graph_config and self.task != Const.TENSOR:
132
134
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
133
135
  f"When level is set to L2, the task must be set to tensor.")
134
- if self.scope:
136
+ if not is_graph_config and self.scope:
135
137
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
136
138
  f"When level is set to L2, the scope cannot be configured.")
137
- if not self.list or len(self.list) != 1:
139
+ if not is_graph_config and (not self.list or len(self.list) != 1):
138
140
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
139
141
  f"When level is set to L2, the list must be configured as a list with one api name.")
140
142