mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/docs/35.nan_analyze.md
CHANGED
|
@@ -23,10 +23,11 @@
|
|
|
23
23
|
msprobe -f pytorch nan_analyze -i dump_step_path -o output_dir_path
|
|
24
24
|
```
|
|
25
25
|
|
|
26
|
-
| 参数 | 说明
|
|
27
|
-
|
|
26
|
+
| 参数 | 说明 |
|
|
27
|
+
|--------------------|---------------------------------------------|
|
|
28
|
+
| -f 或 --framework | 指定训练框架。pytorch。必选。 |
|
|
28
29
|
| -i 或 --input_path | dump数据的目录。需指定到step层级,如`-i /xxx/dump/step0/` |
|
|
29
|
-
| -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 \"./output/" 目录。
|
|
30
|
+
| -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 \"./output/" 目录。 |
|
|
30
31
|
|
|
31
32
|
### 输出文件介绍
|
|
32
33
|
|
msprobe/docs/FAQ.md
CHANGED
|
@@ -36,6 +36,9 @@
|
|
|
36
36
|
该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。
|
|
37
37
|
- 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。
|
|
38
38
|
|
|
39
|
+
|
|
40
|
+
5. 在使用 msprobe 进行 PyTorch 框架的数据采集功能时,请注意确认环境变量 NPU_ASD_ENABLE=0,即关闭特征值检测功能。由于工具冲突,在该功能开启的情况下可能导致某些 API 数据采集的缺失。
|
|
41
|
+
|
|
39
42
|
# 2 精度预检(PyTorch)
|
|
40
43
|
|
|
41
44
|
1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)?
|
msprobe/docs/img/ms_layer.png
CHANGED
|
Binary file
|
|
@@ -13,6 +13,13 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from typing import (
|
|
17
|
+
Any,
|
|
18
|
+
Dict,
|
|
19
|
+
List,
|
|
20
|
+
Tuple,
|
|
21
|
+
Union
|
|
22
|
+
)
|
|
16
23
|
import os
|
|
17
24
|
import numpy as np
|
|
18
25
|
import mindspore
|
|
@@ -39,6 +46,22 @@ if torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
|
39
46
|
else:
|
|
40
47
|
import torch
|
|
41
48
|
|
|
49
|
+
# 为了可读性,我们先给每种返回形态起个别名
|
|
50
|
+
ForwardResult = Tuple[
|
|
51
|
+
List[ComputeElement],
|
|
52
|
+
Tuple[Any, ...],
|
|
53
|
+
Dict[str, Any],
|
|
54
|
+
Tuple[Any, ...],
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
BackwardResultMT = Tuple[
|
|
58
|
+
List[ComputeElement],
|
|
59
|
+
Union[Any, Tuple[Any, ...]],
|
|
60
|
+
Tuple[Any, ...],
|
|
61
|
+
]
|
|
62
|
+
|
|
63
|
+
PyTorchBackward = List[ComputeElement]
|
|
64
|
+
|
|
42
65
|
|
|
43
66
|
class ApiInputAggregation:
|
|
44
67
|
def __init__(self, inputs, kwargs, gradient_inputs) -> None:
|
|
@@ -179,7 +202,12 @@ class ApiRunner:
|
|
|
179
202
|
return api_instance
|
|
180
203
|
|
|
181
204
|
@staticmethod
|
|
182
|
-
def run_api(
|
|
205
|
+
def run_api(
|
|
206
|
+
api_instance,
|
|
207
|
+
api_input_aggregation,
|
|
208
|
+
forward_or_backward: str,
|
|
209
|
+
api_platform: str,
|
|
210
|
+
) -> Union[ForwardResult, BackwardResultMT, PyTorchBackward]:
|
|
183
211
|
inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
|
|
184
212
|
for compute_element in api_input_aggregation.inputs)
|
|
185
213
|
kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import threading
|
|
16
17
|
from collections import OrderedDict
|
|
17
18
|
|
|
18
19
|
from mindspore import Tensor
|
|
@@ -21,6 +22,8 @@ from mindspore.ops.operations import _inner_ops as inner
|
|
|
21
22
|
|
|
22
23
|
from msprobe.core.common.const import Const
|
|
23
24
|
from msprobe.core.common.exceptions import MsprobeException
|
|
25
|
+
from msprobe.core.common.runtime import Runtime
|
|
26
|
+
from msprobe.core.common.utils import ModuleQueue, ThreadSafe
|
|
24
27
|
from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
|
|
25
28
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
29
|
from msprobe.mindspore.common.log import logger
|
|
@@ -32,7 +35,6 @@ from msprobe.mindspore.common.utils import (
|
|
|
32
35
|
)
|
|
33
36
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
34
37
|
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
35
|
-
from msprobe.core.common.runtime import Runtime
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
def get_cell_construct(construct):
|
|
@@ -40,13 +42,15 @@ def get_cell_construct(construct):
|
|
|
40
42
|
if hasattr(self, 'msprobe_hook'):
|
|
41
43
|
setattr(self, 'msprobe_input_kwargs', kwargs)
|
|
42
44
|
return construct(self, *args, **kwargs)
|
|
45
|
+
|
|
43
46
|
return _construct
|
|
44
47
|
|
|
45
48
|
|
|
46
49
|
class CellProcessor:
|
|
50
|
+
cell_queue = ModuleQueue()
|
|
47
51
|
cell_count = {}
|
|
48
|
-
cell_stack =
|
|
49
|
-
api_parent_node =
|
|
52
|
+
cell_stack = {}
|
|
53
|
+
api_parent_node = {}
|
|
50
54
|
module_node = {}
|
|
51
55
|
cell_bw_hook_kernels = {}
|
|
52
56
|
cell_backward_pre_hook = []
|
|
@@ -65,9 +69,10 @@ class CellProcessor:
|
|
|
65
69
|
|
|
66
70
|
@classmethod
|
|
67
71
|
def reset_cell_stats(cls):
|
|
72
|
+
cls.cell_queue = ModuleQueue()
|
|
68
73
|
cls.cell_count = {}
|
|
69
|
-
cls.cell_stack =
|
|
70
|
-
cls.api_parent_node =
|
|
74
|
+
cls.cell_stack = {}
|
|
75
|
+
cls.api_parent_node = {}
|
|
71
76
|
cls.module_node = {}
|
|
72
77
|
cls.cell_bw_hook_kernels = {}
|
|
73
78
|
cls.cell_backward_pre_hook = []
|
|
@@ -122,6 +127,7 @@ class CellProcessor:
|
|
|
122
127
|
GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
|
|
123
128
|
|
|
124
129
|
def build_cell_hook(self, cell_name, build_data_hook):
|
|
130
|
+
@ThreadSafe.synchronized
|
|
125
131
|
def forward_pre_hook(cell, args):
|
|
126
132
|
index = CellProcessor.set_and_get_calls_number(cell_name)
|
|
127
133
|
full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
|
|
@@ -146,11 +152,13 @@ class CellProcessor:
|
|
|
146
152
|
setattr(cell, 'msprobe_forward_hook', True)
|
|
147
153
|
|
|
148
154
|
def get_backward_hook(backward_data_hook, full_backward_name):
|
|
155
|
+
@ThreadSafe.synchronized
|
|
149
156
|
def backward_hook_fn(cell, grad_input, grad_output):
|
|
150
157
|
new_output = backward_data_hook(cell, grad_input, grad_output)
|
|
151
158
|
self.set_construct_info_in_hook(full_backward_name)
|
|
152
159
|
cell.has_pre_hook_called = False
|
|
153
160
|
return new_output
|
|
161
|
+
|
|
154
162
|
return backward_hook_fn
|
|
155
163
|
|
|
156
164
|
enable_hooked = sum(
|
|
@@ -170,13 +178,14 @@ class CellProcessor:
|
|
|
170
178
|
|
|
171
179
|
return args
|
|
172
180
|
|
|
181
|
+
@ThreadSafe.synchronized
|
|
173
182
|
def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
|
|
174
183
|
index = CellProcessor.cell_count.get(cell_name, 0)
|
|
175
184
|
full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
|
|
176
185
|
full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
|
|
177
186
|
|
|
178
187
|
self.set_construct_info_in_hook(full_forward_name)
|
|
179
|
-
|
|
188
|
+
|
|
180
189
|
hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
|
|
181
190
|
hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
|
|
182
191
|
if hook_result is not None:
|
|
@@ -199,6 +208,7 @@ class CellProcessor:
|
|
|
199
208
|
outputs = new_outputs
|
|
200
209
|
|
|
201
210
|
def get_backward_pre_hook(full_backward_name, backward_data_hook):
|
|
211
|
+
@ThreadSafe.synchronized
|
|
202
212
|
def backward_pre_hook_fn(cell, grad_output):
|
|
203
213
|
cell.has_pre_hook_called = True
|
|
204
214
|
self.set_construct_info_in_pre_hook(full_backward_name)
|
|
@@ -206,6 +216,7 @@ class CellProcessor:
|
|
|
206
216
|
backward_data_hook(cell, (), grad_output)
|
|
207
217
|
self.set_construct_info_in_hook(full_backward_name)
|
|
208
218
|
cell.has_pre_hook_called = False
|
|
219
|
+
|
|
209
220
|
return backward_pre_hook_fn
|
|
210
221
|
|
|
211
222
|
backward_pre_hook = OrderedDict()
|
|
@@ -233,18 +244,28 @@ class CellProcessor:
|
|
|
233
244
|
return forward_pre_hook
|
|
234
245
|
|
|
235
246
|
def set_construct_info_in_pre_hook(self, full_name):
|
|
236
|
-
|
|
237
|
-
|
|
247
|
+
tid = threading.get_ident()
|
|
248
|
+
if tid not in self.cell_stack:
|
|
249
|
+
CellProcessor.cell_stack[tid] = []
|
|
250
|
+
|
|
251
|
+
if self.cell_stack[tid]:
|
|
252
|
+
CellProcessor.module_node[full_name] = self.cell_stack[tid][-1]
|
|
238
253
|
else:
|
|
239
|
-
CellProcessor.
|
|
240
|
-
|
|
241
|
-
|
|
254
|
+
parent_name = CellProcessor.cell_queue.find_last(full_name)
|
|
255
|
+
CellProcessor.module_node[full_name] = parent_name
|
|
256
|
+
|
|
257
|
+
CellProcessor.cell_queue.add_name(full_name)
|
|
258
|
+
CellProcessor.cell_stack[tid].append(full_name)
|
|
259
|
+
CellProcessor.api_parent_node[tid] = full_name
|
|
242
260
|
if self.scope:
|
|
243
261
|
self.scope.begin_module(full_name)
|
|
244
262
|
|
|
245
263
|
def set_construct_info_in_hook(self, full_name):
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
264
|
+
tid = threading.get_ident()
|
|
265
|
+
CellProcessor.api_parent_node[tid] = None
|
|
266
|
+
if self.cell_stack.get(tid):
|
|
267
|
+
CellProcessor.cell_stack[tid].pop()
|
|
268
|
+
if self.cell_stack.get(tid):
|
|
269
|
+
CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1]
|
|
249
270
|
if self.scope:
|
|
250
271
|
self.scope.end_module(full_name)
|
|
@@ -119,9 +119,17 @@ def find_npy_files(npy_path):
|
|
|
119
119
|
|
|
120
120
|
# 如果是目录,使用Path.rglob查找所有.npy文件
|
|
121
121
|
if npy_path_obj.is_dir():
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
122
|
+
base_depth = len(npy_path_obj.resolve().parts)
|
|
123
|
+
for root, dirs, files in os.walk(npy_path_obj):
|
|
124
|
+
current_depth = len(Path(root).resolve().parts) - base_depth
|
|
125
|
+
if current_depth >= 10:
|
|
126
|
+
dirs[:] = []
|
|
127
|
+
|
|
128
|
+
for filename in files:
|
|
129
|
+
if filename.endswith(Const.NUMPY_SUFFIX):
|
|
130
|
+
file_path = Path(root) / filename
|
|
131
|
+
check_file_or_directory_path(file_path)
|
|
132
|
+
npy_files.append(file_path.resolve())
|
|
125
133
|
else:
|
|
126
134
|
logger.info(f"The specified path is neither an .npy file nor a directory: {npy_path}")
|
|
127
135
|
|
|
@@ -254,7 +262,18 @@ def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict
|
|
|
254
262
|
corresponding_name = None
|
|
255
263
|
name_without_ext = os.path.splitext(corresponding_name)[0]
|
|
256
264
|
npy_path = os.path.realpath(npy_file)
|
|
257
|
-
|
|
265
|
+
|
|
266
|
+
parts = name_without_ext.split(".")
|
|
267
|
+
if len(parts) < 2:
|
|
268
|
+
logger.error(
|
|
269
|
+
f'File name "{file_name}" in "{directory}" '
|
|
270
|
+
f'does not conform to expected format (missing scope separator ".")!'
|
|
271
|
+
)
|
|
272
|
+
raise Exception(
|
|
273
|
+
f'File name "{file_name}" has incorrect format, cannot extract node scope!'
|
|
274
|
+
)
|
|
275
|
+
node_scope = parts[1]
|
|
276
|
+
|
|
258
277
|
trie = Trie()
|
|
259
278
|
for key, value in match_dict.items():
|
|
260
279
|
trie.insert(key, value)
|
|
@@ -77,7 +77,7 @@ class Parser:
|
|
|
77
77
|
|
|
78
78
|
@staticmethod
|
|
79
79
|
def extract_constants(inputs_str: str) -> List[str]:
|
|
80
|
-
constant_pattern = re.compile(r'\b(\
|
|
80
|
+
constant_pattern = re.compile(r'\b([A-Za-z_][A-Za-z0-9_]{0,10000})\(([A-Za-z0-9_\s,.\-+/]{0,10000})\)')
|
|
81
81
|
constants = constant_pattern.findall(inputs_str)
|
|
82
82
|
return constants
|
|
83
83
|
|
|
@@ -90,7 +90,8 @@ class Parser:
|
|
|
90
90
|
self.nodes[func_name] = func_graph_info
|
|
91
91
|
|
|
92
92
|
def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None:
|
|
93
|
-
node_pattern = re.compile(
|
|
93
|
+
node_pattern = re.compile(
|
|
94
|
+
r'(%\d{1,10000})\(([A-Za-z0-9_\.]{1,10000})\)\s*=\s*([A-Za-z_][A-Za-z0-9_]{0,10000})\(')
|
|
94
95
|
matches = list(node_pattern.finditer(text))
|
|
95
96
|
for i, match in enumerate(matches):
|
|
96
97
|
series_number = match.group(1)
|
|
@@ -106,8 +107,9 @@ class Parser:
|
|
|
106
107
|
|
|
107
108
|
constants = self.__class__.extract_constants(args_str)
|
|
108
109
|
|
|
109
|
-
scope_pattern = re.compile(
|
|
110
|
-
|
|
110
|
+
scope_pattern = re.compile(
|
|
111
|
+
r'^(?=.{0,300}$)[ \t]*\#[ \t]*[^\r\n]*?scope[^\r\n]*?:[ \t]*\(([^)\r\n]{1,200})\)[ \t]*$',
|
|
112
|
+
re.IGNORECASE | re.MULTILINE)
|
|
111
113
|
scope_match = scope_pattern.search(text, end_pos)
|
|
112
114
|
scope = scope_match.group(1) if scope_match else ""
|
|
113
115
|
|
|
@@ -95,6 +95,9 @@ def save_tensor_as_npy(tensor, file_path):
|
|
|
95
95
|
|
|
96
96
|
|
|
97
97
|
def convert_to_int(value):
|
|
98
|
+
if isinstance(value, bool):
|
|
99
|
+
logger.error('The value in rank_id or step should be int, please check!')
|
|
100
|
+
raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
|
|
98
101
|
try:
|
|
99
102
|
return int(value)
|
|
100
103
|
except Exception:
|
|
@@ -30,9 +30,10 @@ from msprobe.core.common.utils import CompareException
|
|
|
30
30
|
from msprobe.core.common.exceptions import FileCheckException
|
|
31
31
|
from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
|
|
32
32
|
check_path_before_create, load_npy
|
|
33
|
-
from msprobe.core.common.const import CompareConst
|
|
33
|
+
from msprobe.core.common.const import CompareConst
|
|
34
34
|
from msprobe.core.compare.npy_compare import compare_ops_apply
|
|
35
35
|
from msprobe.core.compare.multiprocessing_compute import check_accuracy
|
|
36
|
+
from msprobe.mindspore.compare.utils import check_name_map_dict
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
|
|
@@ -49,6 +50,7 @@ def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataF
|
|
|
49
50
|
npu_root = Path(input_params.get('npu_path'))
|
|
50
51
|
bench_root = Path(input_params.get('bench_path'))
|
|
51
52
|
name_map_dict = input_params.get('map_dict', {})
|
|
53
|
+
check_name_map_dict(name_map_dict)
|
|
52
54
|
file_tree = build_mirror_file_tree(npu_root, bench_root)
|
|
53
55
|
|
|
54
56
|
# 处理文件比对
|
|
@@ -114,24 +116,42 @@ def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple
|
|
|
114
116
|
file_tree = {}
|
|
115
117
|
|
|
116
118
|
# 遍历NPU目录构建树结构
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
except FileCheckException:
|
|
119
|
+
# 使用os.walk遍历目录,限制深度为10层
|
|
120
|
+
for root, dirs, files in os.walk(npu_root):
|
|
121
|
+
# 计算当前目录深度
|
|
122
|
+
depth = len(Path(root).relative_to(npu_root).parts)
|
|
123
|
+
if depth > 10:
|
|
124
|
+
dirs.clear() # 清空dirs列表以阻止继续递归
|
|
124
125
|
continue
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
126
|
+
|
|
127
|
+
# 检查当前目录下是否有npy文件
|
|
128
|
+
if any(f.endswith('.npy') for f in files):
|
|
129
|
+
# 获取相对路径
|
|
130
|
+
dir_path = Path(root).relative_to(npu_root)
|
|
131
|
+
npu_dir_pair = os.path.join(npu_root, dir_path)
|
|
132
|
+
bench_dir_pair = os.path.join(bench_root, dir_path)
|
|
133
|
+
|
|
134
|
+
try:
|
|
135
|
+
check_file_or_directory_path(bench_dir_pair, isdir=True)
|
|
136
|
+
except FileCheckException:
|
|
137
|
+
continue
|
|
138
|
+
|
|
139
|
+
# 添加到文件树
|
|
140
|
+
if dir_path not in file_tree:
|
|
141
|
+
file_tree[dir_path] = (npu_dir_pair, bench_dir_pair)
|
|
128
142
|
|
|
129
143
|
return file_tree
|
|
130
144
|
|
|
131
145
|
|
|
132
146
|
def find_npy_files(directory):
|
|
133
147
|
npy_files_dict = {}
|
|
134
|
-
|
|
148
|
+
# 限制递归深度为1层,即只遍历当前目录和其直接子目录
|
|
149
|
+
for root, dirs, files in os.walk(directory, topdown=True):
|
|
150
|
+
# 计算当前目录深度
|
|
151
|
+
depth = root[len(directory):].count(os.sep)
|
|
152
|
+
# 如果深度超过10层则跳过
|
|
153
|
+
if depth > 10:
|
|
154
|
+
dirs.clear()
|
|
135
155
|
for file in files:
|
|
136
156
|
if file.endswith(".npy"):
|
|
137
157
|
# 分割文件名并去掉最后两个元素
|
|
@@ -168,8 +168,13 @@ class GraphMSComparator:
|
|
|
168
168
|
self.output_path = output_path
|
|
169
169
|
self.base_npu_path = input_param.get('npu_path', None)
|
|
170
170
|
self.base_bench_path = input_param.get('bench_path', None)
|
|
171
|
-
|
|
172
|
-
|
|
171
|
+
rank_id_list = input_param.get('rank_id', [])
|
|
172
|
+
step_id_list = input_param.get('step_id', [])
|
|
173
|
+
if not isinstance(rank_id_list, list) or not isinstance(step_id_list, list):
|
|
174
|
+
logger.error("'rank_id' and 'step_id' should both be lists, please check!")
|
|
175
|
+
raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
|
|
176
|
+
self.rank_list = [convert_to_int(rank_id) for rank_id in rank_id_list]
|
|
177
|
+
self.step_list = [convert_to_int(step_id) for step_id in step_id_list]
|
|
173
178
|
# split by rank and step, generate rank step path
|
|
174
179
|
self.npu_rank_step_dict = self.generate_rank_step_path(self.base_npu_path)
|
|
175
180
|
self.bench_rank_step_dict = self.generate_rank_step_path(self.base_bench_path)
|
|
@@ -17,7 +17,8 @@ import os
|
|
|
17
17
|
|
|
18
18
|
from msprobe.core.common.const import Const
|
|
19
19
|
from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst
|
|
20
|
-
from msprobe.core.common.utils import detect_framework_by_dump_json
|
|
20
|
+
from msprobe.core.common.utils import detect_framework_by_dump_json, CompareException, check_op_str_pattern_valid
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
def read_npy_data(dir_path, file_name):
|
|
@@ -35,3 +36,10 @@ def read_npy_data(dir_path, file_name):
|
|
|
35
36
|
def check_cross_framework(bench_json_path):
|
|
36
37
|
framework = detect_framework_by_dump_json(bench_json_path)
|
|
37
38
|
return framework == Const.PT_FRAMEWORK
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def check_name_map_dict(name_map_dict):
|
|
42
|
+
if not isinstance(name_map_dict, dict):
|
|
43
|
+
logger.error("'map_dict' should be a dict, please check!")
|
|
44
|
+
raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
|
|
45
|
+
check_op_str_pattern_valid(str(name_map_dict))
|
|
@@ -81,18 +81,22 @@ class DebuggerConfig:
|
|
|
81
81
|
target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
|
|
82
82
|
if models is None or isinstance(models, target_module_type[0]):
|
|
83
83
|
return models
|
|
84
|
-
error_model = None
|
|
85
84
|
if isinstance(models, (list, tuple)):
|
|
85
|
+
error_model = None
|
|
86
86
|
for model in models:
|
|
87
87
|
if not isinstance(model, target_module_type[0]):
|
|
88
88
|
error_model = model
|
|
89
89
|
break
|
|
90
|
-
|
|
91
|
-
|
|
90
|
+
if error_model is not None:
|
|
91
|
+
error_info = (
|
|
92
|
+
f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
93
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
94
|
+
raise MsprobeException(
|
|
95
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
92
96
|
|
|
93
|
-
|
|
97
|
+
else:
|
|
94
98
|
error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
95
|
-
f"type, currently there is a {type(
|
|
99
|
+
f"type, currently there is a {type(models)} type.")
|
|
96
100
|
raise MsprobeException(
|
|
97
101
|
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
98
102
|
return models
|
|
@@ -125,16 +129,14 @@ class DebuggerConfig:
|
|
|
125
129
|
self.level_ori = Const.LEVEL_MIX
|
|
126
130
|
return True
|
|
127
131
|
|
|
128
|
-
def check_config_with_l2(self):
|
|
129
|
-
if self.
|
|
130
|
-
return
|
|
131
|
-
if self.task != Const.TENSOR:
|
|
132
|
+
def check_config_with_l2(self, is_graph_config):
|
|
133
|
+
if not is_graph_config and self.task != Const.TENSOR:
|
|
132
134
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
133
135
|
f"When level is set to L2, the task must be set to tensor.")
|
|
134
|
-
if self.scope:
|
|
136
|
+
if not is_graph_config and self.scope:
|
|
135
137
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
136
138
|
f"When level is set to L2, the scope cannot be configured.")
|
|
137
|
-
if not self.list or len(self.list) != 1:
|
|
139
|
+
if not is_graph_config and (not self.list or len(self.list) != 1):
|
|
138
140
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
139
141
|
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
140
142
|
|