mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -14,21 +14,22 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import collections
|
|
17
|
+
import functools
|
|
18
|
+
import inspect
|
|
17
19
|
import os
|
|
18
20
|
import re
|
|
19
|
-
import
|
|
21
|
+
import threading
|
|
20
22
|
import time
|
|
21
|
-
import
|
|
23
|
+
from collections import OrderedDict
|
|
22
24
|
from datetime import datetime, timezone
|
|
23
25
|
|
|
24
26
|
import numpy as np
|
|
25
27
|
|
|
26
|
-
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
27
28
|
from msprobe.core.common.const import Const, CompareConst
|
|
28
|
-
from msprobe.core.common.log import logger
|
|
29
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
30
29
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
31
|
-
|
|
30
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
31
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
32
|
+
from msprobe.core.common.log import logger
|
|
32
33
|
|
|
33
34
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
34
35
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
@@ -112,6 +113,82 @@ class DumpException(MsprobeBaseException):
|
|
|
112
113
|
return f"Dump Error Code {self.code}: {self.error_info}"
|
|
113
114
|
|
|
114
115
|
|
|
116
|
+
class ThreadSafe:
|
|
117
|
+
"""
|
|
118
|
+
线程安全控制工具类,提供三种使用方式:
|
|
119
|
+
1.上下文管理器:with ThreadSafe()
|
|
120
|
+
2.主动加锁与释放锁:ThreadSafe.acquire()/ThreadSafe.release()
|
|
121
|
+
3.方法装饰器:@ThreadSafe.synchronized
|
|
122
|
+
"""
|
|
123
|
+
_lock = threading.RLock()
|
|
124
|
+
|
|
125
|
+
def __enter__(self):
|
|
126
|
+
self.__class__._lock.acquire()
|
|
127
|
+
|
|
128
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
129
|
+
self.__class__._lock.release()
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def acquire(cls):
|
|
133
|
+
cls._lock.acquire()
|
|
134
|
+
|
|
135
|
+
@classmethod
|
|
136
|
+
def release(cls):
|
|
137
|
+
cls._lock.release()
|
|
138
|
+
|
|
139
|
+
@classmethod
|
|
140
|
+
def synchronized(cls, func):
|
|
141
|
+
@functools.wraps(func)
|
|
142
|
+
def wrapper(*args, **kwargs):
|
|
143
|
+
with cls._lock:
|
|
144
|
+
return func(*args, **kwargs)
|
|
145
|
+
|
|
146
|
+
return wrapper
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class ModuleQueue:
|
|
150
|
+
def __init__(self):
|
|
151
|
+
self.queue = OrderedDict()
|
|
152
|
+
|
|
153
|
+
def add_name(self, name):
|
|
154
|
+
self.queue[name] = True
|
|
155
|
+
|
|
156
|
+
def remove_name(self, name):
|
|
157
|
+
if name in self.queue:
|
|
158
|
+
del self.queue[name]
|
|
159
|
+
|
|
160
|
+
def find_last(self, name):
|
|
161
|
+
"""
|
|
162
|
+
在队列中找到当前 Module/Cell 的父节点名称并返回,若找不到则返回None
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
name: 需要寻找父节点的 Module/Cell 的名称
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
返回父节点名称,找不到则返回None
|
|
169
|
+
|
|
170
|
+
Examples:
|
|
171
|
+
父节点名称格式: Module.module1.module1.forward.0
|
|
172
|
+
子节点名称格式: Module.module1.module2.Module2.forward.0
|
|
173
|
+
匹配关系: Module/Cell 的名称总能被点(.)分割符分成5个部分及以上,子节点截断后4个点和父节点截断后3个点的前缀名称是匹配的
|
|
174
|
+
"""
|
|
175
|
+
child_parts = name.split('.')
|
|
176
|
+
if len(child_parts) < 5:
|
|
177
|
+
return None
|
|
178
|
+
child_name_prefix = '.'.join(child_parts[:-4])
|
|
179
|
+
if child_name_prefix in Const.MODULE_PREFIX:
|
|
180
|
+
return None
|
|
181
|
+
|
|
182
|
+
for parent_name in reversed(self.queue):
|
|
183
|
+
parent_parts = parent_name.split('.')
|
|
184
|
+
if len(parent_parts) < 5:
|
|
185
|
+
return None
|
|
186
|
+
parent_name_prefix = '.'.join(parent_parts[:-3])
|
|
187
|
+
if parent_name_prefix == child_name_prefix:
|
|
188
|
+
return parent_name
|
|
189
|
+
return None
|
|
190
|
+
|
|
191
|
+
|
|
115
192
|
def is_json_file(file_path):
|
|
116
193
|
if isinstance(file_path, str) and file_path.lower().endswith('.json'):
|
|
117
194
|
return True
|
|
@@ -156,9 +233,10 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode):
|
|
|
156
233
|
|
|
157
234
|
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
|
|
158
235
|
arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
|
|
159
|
-
|
|
236
|
+
arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'is_print_compare_log']
|
|
237
|
+
for arg, name in zip(arg_list, arg_names):
|
|
160
238
|
if not isinstance(arg, bool):
|
|
161
|
-
logger.error(f"Invalid input parameter, {
|
|
239
|
+
logger.error(f"Invalid input parameter, {name} which should be only bool type.")
|
|
162
240
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
163
241
|
|
|
164
242
|
|
|
@@ -282,9 +360,9 @@ def set_dump_path(input_param):
|
|
|
282
360
|
npu_path = input_param.get("npu_json_path", None)
|
|
283
361
|
bench_path = input_param.get("bench_json_path", None)
|
|
284
362
|
dump_json_path_valid = npu_path is not None and npu_path.endswith("dump.json") and \
|
|
285
|
-
|
|
363
|
+
bench_path is not None and bench_path.endswith("dump.json")
|
|
286
364
|
debug_json_path_valid = npu_path is not None and npu_path.endswith("debug.json") and \
|
|
287
|
-
|
|
365
|
+
bench_path is not None and bench_path.endswith("debug.json")
|
|
288
366
|
if not dump_json_path_valid and not debug_json_path_valid:
|
|
289
367
|
logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
|
|
290
368
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
@@ -457,10 +535,10 @@ def get_real_step_or_rank(step_or_rank_input, obj):
|
|
|
457
535
|
def check_init_step(step):
|
|
458
536
|
if not is_int(step):
|
|
459
537
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
460
|
-
|
|
538
|
+
f"{step} must be an integer")
|
|
461
539
|
if not step >= 0:
|
|
462
540
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
463
|
-
|
|
541
|
+
f"{step} must be greater than or equal to 0")
|
|
464
542
|
|
|
465
543
|
|
|
466
544
|
def check_token_range(token_range):
|
|
@@ -568,14 +646,25 @@ def replace_last_occurrence(text, old, new):
|
|
|
568
646
|
|
|
569
647
|
def load_stack_json(stack_path):
|
|
570
648
|
stack_dict = load_json(stack_path)
|
|
649
|
+
|
|
650
|
+
if not isinstance(stack_dict, dict):
|
|
651
|
+
raise MsprobeException(
|
|
652
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
653
|
+
"The format of the stack.json is incorrect, the outermost layer of stack.json should be a dict type."
|
|
654
|
+
)
|
|
655
|
+
|
|
571
656
|
if not stack_dict.get(Const.NEW_STACK_FLAG):
|
|
572
657
|
return stack_dict
|
|
573
658
|
|
|
574
659
|
new_stack_dict = {}
|
|
575
660
|
for stack_info in stack_dict.values():
|
|
576
|
-
if len(stack_info) != 2:
|
|
661
|
+
if not isinstance(stack_info, list) or len(stack_info) != 2:
|
|
577
662
|
continue
|
|
663
|
+
|
|
578
664
|
api_list, stack_str = stack_info
|
|
665
|
+
if not isinstance(api_list, list):
|
|
666
|
+
continue
|
|
667
|
+
|
|
579
668
|
for api_name in api_list:
|
|
580
669
|
new_stack_dict.update({api_name: stack_str})
|
|
581
670
|
return new_stack_dict
|
|
@@ -597,3 +686,18 @@ def analyze_api_call_stack(name):
|
|
|
597
686
|
else:
|
|
598
687
|
stack_str.append(Const.WITHOUT_CALL_STACK)
|
|
599
688
|
return "".join(stack_str)
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
def check_extern_input_list(input_list):
|
|
692
|
+
if not isinstance(input_list, list):
|
|
693
|
+
raise Exception("input is not a list")
|
|
694
|
+
if len(input_list) > Const.EXTERN_INPUT_LIST_MAX_LEN:
|
|
695
|
+
raise Exception(f"input list exceed max length {Const.EXTERN_INPUT_LIST_MAX_LEN}")
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def check_process_num(process_num):
|
|
699
|
+
if not is_int(process_num) or process_num <= 0:
|
|
700
|
+
raise ValueError(f"process_num({process_num}) is not a positive integer")
|
|
701
|
+
if process_num > Const.MAX_PROCESS_NUM:
|
|
702
|
+
raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.")
|
|
703
|
+
|
msprobe/core/common_config.py
CHANGED
|
@@ -13,7 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
17
19
|
from msprobe.core.common.log import logger
|
|
18
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
19
21
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
@@ -67,6 +69,7 @@ class BaseConfig:
|
|
|
67
69
|
self.if_preheat = json_config.get("if_preheat")
|
|
68
70
|
self.preheat_step = json_config.get("preheat_step")
|
|
69
71
|
self.max_sample = json_config.get("max_sample")
|
|
72
|
+
self.is_regex_valid = True
|
|
70
73
|
|
|
71
74
|
@staticmethod
|
|
72
75
|
def _check_str_list_config(config_item, config_name):
|
|
@@ -83,6 +86,7 @@ class BaseConfig:
|
|
|
83
86
|
self._check_str_list_config(self.scope, "scope")
|
|
84
87
|
self._check_str_list_config(self.list, "list")
|
|
85
88
|
self._check_data_mode()
|
|
89
|
+
self._check_regex_in_list()
|
|
86
90
|
|
|
87
91
|
def _check_data_mode(self):
|
|
88
92
|
if self.data_mode is not None:
|
|
@@ -118,3 +122,13 @@ class BaseConfig:
|
|
|
118
122
|
f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.",
|
|
119
123
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
120
124
|
)
|
|
125
|
+
|
|
126
|
+
def _check_regex_in_list(self):
|
|
127
|
+
if self.list:
|
|
128
|
+
for name in self.list:
|
|
129
|
+
if name.startswith('name-regex(') and name.endswith(')'):
|
|
130
|
+
try:
|
|
131
|
+
re.compile(name[len('name-regex('):-1])
|
|
132
|
+
except re.error:
|
|
133
|
+
self.is_regex_valid = False
|
|
134
|
+
break
|
|
@@ -31,7 +31,7 @@ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, chec
|
|
|
31
31
|
set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type
|
|
32
32
|
from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
|
|
33
33
|
from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \
|
|
34
|
-
reorder_op_x_list, set_stack_json_path
|
|
34
|
+
reorder_op_x_list, set_stack_json_path, check_api_info_len
|
|
35
35
|
from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
|
|
36
36
|
from msprobe.core.compare.multiprocessing_compute import CompareRealData
|
|
37
37
|
from msprobe.core.compare.highlight import HighLight
|
|
@@ -211,25 +211,37 @@ class ParseData:
|
|
|
211
211
|
for index, op_name in enumerate(op_name_reorder):
|
|
212
212
|
result[CompareConst.OP_NAME].append(op_name)
|
|
213
213
|
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
214
|
-
|
|
214
|
+
info_list = merge_list[CompareConst.INPUT_STRUCT]
|
|
215
215
|
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
216
|
-
|
|
216
|
+
info_list = merge_list[CompareConst.OUTPUT_STRUCT]
|
|
217
217
|
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
218
|
-
|
|
218
|
+
info_list = merge_list[CompareConst.PARAMS_STRUCT]
|
|
219
219
|
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
220
|
-
|
|
220
|
+
info_list = merge_list[CompareConst.PARAMS_GRAD_STRUCT]
|
|
221
221
|
else:
|
|
222
|
-
|
|
222
|
+
info_list = merge_list[CompareConst.DEBUG_STRUCT]
|
|
223
|
+
check_api_info_len(op_name, info_list, 1)
|
|
224
|
+
struct = info_list.pop(0)
|
|
225
|
+
|
|
226
|
+
check_api_info_len(op_name, struct, 2)
|
|
223
227
|
result[Const.DTYPE].append(struct[0])
|
|
224
228
|
result[Const.SHAPE].append(struct[1])
|
|
225
229
|
if self.mode_config.dump_mode == Const.MD5:
|
|
230
|
+
check_api_info_len(op_name, struct, 3)
|
|
226
231
|
result[Const.MD5].append(struct[2])
|
|
232
|
+
|
|
233
|
+
check_api_info_len(op_name, summary_reorder, 1)
|
|
227
234
|
result[Const.SUMMARY].append(summary_reorder.pop(0))
|
|
228
|
-
|
|
229
|
-
|
|
235
|
+
|
|
236
|
+
if index == 0 and self.mode_config.stack_mode:
|
|
237
|
+
check_api_info_len(op_name, merge_list[Const.STACK_INFO], 1)
|
|
238
|
+
result[Const.STACK_INFO].append(merge_list[Const.STACK_INFO][0])
|
|
239
|
+
else:
|
|
240
|
+
result[Const.STACK_INFO].append(None)
|
|
241
|
+
|
|
230
242
|
if self.mode_config.dump_mode == Const.ALL:
|
|
243
|
+
check_api_info_len(op_name, data_name_reorder, 1)
|
|
231
244
|
result['data_name'].append(data_name_reorder.pop(0))
|
|
232
|
-
|
|
233
245
|
progress_bar.update(1)
|
|
234
246
|
progress_bar.close()
|
|
235
247
|
return pd.DataFrame(result)
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
from msprobe.core.common.file_utils import check_file_type, load_json
|
|
17
|
+
from msprobe.core.common.file_utils import check_file_type, load_json, check_file_or_directory_path
|
|
18
18
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
19
19
|
from msprobe.core.common.utils import CompareException
|
|
20
20
|
from msprobe.core.common.log import logger
|
|
@@ -22,6 +22,9 @@ from msprobe.core.common.log import logger
|
|
|
22
22
|
|
|
23
23
|
def compare_cli(args):
|
|
24
24
|
input_param = load_json(args.input_path)
|
|
25
|
+
if not isinstance(input_param, dict):
|
|
26
|
+
logger.error("input_param should be dict, please check!")
|
|
27
|
+
raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
|
|
25
28
|
npu_path = input_param.get("npu_path", None)
|
|
26
29
|
bench_path = input_param.get("bench_path", None)
|
|
27
30
|
if not npu_path:
|
|
@@ -47,6 +50,8 @@ def compare_cli(args):
|
|
|
47
50
|
}
|
|
48
51
|
|
|
49
52
|
if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
|
|
53
|
+
check_file_or_directory_path(npu_path)
|
|
54
|
+
check_file_or_directory_path(bench_path)
|
|
50
55
|
input_param["npu_json_path"] = input_param.pop("npu_path")
|
|
51
56
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
52
57
|
if "stack_path" not in input_param:
|
|
@@ -68,6 +73,8 @@ def compare_cli(args):
|
|
|
68
73
|
}
|
|
69
74
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
70
75
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
76
|
+
check_file_or_directory_path(npu_path, isdir=True)
|
|
77
|
+
check_file_or_directory_path(bench_path, isdir=True)
|
|
71
78
|
kwargs = {
|
|
72
79
|
**common_kwargs,
|
|
73
80
|
"stack_mode": args.stack_mode,
|
|
@@ -79,7 +86,8 @@ def compare_cli(args):
|
|
|
79
86
|
if input_param.get("rank_id") is not None:
|
|
80
87
|
ms_graph_compare(input_param, args.output_path)
|
|
81
88
|
return
|
|
82
|
-
|
|
89
|
+
common = input_param.get("common", False)
|
|
90
|
+
if isinstance(common, bool) and common:
|
|
83
91
|
common_dir_compare(input_param, args.output_path)
|
|
84
92
|
return
|
|
85
93
|
if frame_name == Const.PT_FRAMEWORK:
|
|
@@ -196,7 +196,7 @@ def result_process(compare_result_path_list, api_list):
|
|
|
196
196
|
compare_index_dict = {}
|
|
197
197
|
result_df = read_xlsx(compare_result_path)
|
|
198
198
|
|
|
199
|
-
rank_pattern = r"compare_result_rank(\d+)
|
|
199
|
+
rank_pattern = r"compare_result_rank(\d+)"
|
|
200
200
|
rank_num = int(re.search(rank_pattern, os.path.basename(compare_result_path)).group(1))
|
|
201
201
|
logger.info(f"Parsing rank{rank_num} compare result...")
|
|
202
202
|
if not result_df.empty:
|
msprobe/core/compare/utils.py
CHANGED
|
@@ -238,6 +238,12 @@ def merge_tensor(tensor_list, dump_mode):
|
|
|
238
238
|
return op_dict if op_dict[CompareConst.OP_NAME] else {}
|
|
239
239
|
|
|
240
240
|
|
|
241
|
+
def check_api_info_len(op_name, info_list, len_require):
|
|
242
|
+
if len(info_list) < len_require:
|
|
243
|
+
logger.error(f'Index out of bounds error, please check info of api: {op_name}.')
|
|
244
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
245
|
+
|
|
246
|
+
|
|
241
247
|
def print_compare_ends_info():
|
|
242
248
|
total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
243
249
|
logger.info('*' * total_len)
|
|
@@ -509,8 +515,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
509
515
|
|
|
510
516
|
result.append(result_item)
|
|
511
517
|
|
|
512
|
-
|
|
513
|
-
|
|
518
|
+
_, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
|
|
519
|
+
_, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
|
|
514
520
|
|
|
515
521
|
get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
|
|
516
522
|
get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
|
|
18
|
+
from msprobe.core.common.file_utils import check_path_pattern_valid
|
|
18
19
|
from msprobe.core.common.framework_adapter import FmkAdp
|
|
19
20
|
from msprobe.core.common.const import FileCheckConst
|
|
20
21
|
|
|
@@ -32,6 +33,7 @@ class PackInput:
|
|
|
32
33
|
raise Exception(f"model is not torch.nn.Module/mindspore.nn.Cell or module list.")
|
|
33
34
|
if not isinstance(self.output_zip_path, str) or not self.output_zip_path.endswith(FileCheckConst.ZIP_SUFFIX):
|
|
34
35
|
raise Exception(f"output zip path must be a string and ends with '.zip'")
|
|
36
|
+
check_path_pattern_valid(self.output_zip_path)
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
class BaseChecker:
|
|
@@ -20,12 +20,13 @@ from difflib import SequenceMatcher
|
|
|
20
20
|
from typing import Union, List, Dict, Any
|
|
21
21
|
import pandas as pd
|
|
22
22
|
|
|
23
|
+
from msprobe.core.common.utils import check_extern_input_list
|
|
23
24
|
from msprobe.core.config_check.checkers.base_checker import BaseChecker
|
|
24
25
|
from msprobe.core.config_check.config_checker import register_checker_item
|
|
25
26
|
from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict
|
|
26
27
|
from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory
|
|
27
|
-
from msprobe.core.common.file_utils import (
|
|
28
|
-
|
|
28
|
+
from msprobe.core.common.file_utils import (check_file_or_directory_path, create_file_in_zip, load_json,
|
|
29
|
+
load_yaml)
|
|
29
30
|
from msprobe.core.common.const import Const
|
|
30
31
|
|
|
31
32
|
|
|
@@ -47,13 +48,13 @@ class HyperparameterChecker(BaseChecker):
|
|
|
47
48
|
output_zip_path = pack_input.output_zip_path
|
|
48
49
|
|
|
49
50
|
if shell_path:
|
|
50
|
-
|
|
51
|
-
raise TypeError("shell_path should be a list of file paths.")
|
|
51
|
+
check_extern_input_list(shell_path)
|
|
52
52
|
|
|
53
53
|
hyperparameters = {}
|
|
54
54
|
parser_factory = ParserFactory()
|
|
55
55
|
for script_path in shell_path:
|
|
56
56
|
if os.path.isfile(script_path):
|
|
57
|
+
check_file_or_directory_path(script_path)
|
|
57
58
|
parser = parser_factory.get_parser(os.path.splitext(script_path)[1])
|
|
58
59
|
update_dict(hyperparameters, parser.run(os.path.realpath(script_path)))
|
|
59
60
|
else:
|
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
from typing import Dict
|
|
17
17
|
from tqdm import tqdm
|
|
18
18
|
|
|
19
|
-
from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists
|
|
19
|
+
from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists, \
|
|
20
|
+
check_file_or_directory_path
|
|
20
21
|
from msprobe.core.common.log import logger
|
|
21
22
|
from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights
|
|
22
23
|
from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC
|
|
@@ -44,6 +45,8 @@ def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
|
|
|
44
45
|
"""
|
|
45
46
|
|
|
46
47
|
# Load both checkpoints
|
|
48
|
+
check_file_or_directory_path(ckpt_path1, isdir=True)
|
|
49
|
+
check_file_or_directory_path(ckpt_path2, isdir=True)
|
|
47
50
|
check_path_before_create(output_path)
|
|
48
51
|
check_path_not_exists(output_path)
|
|
49
52
|
weights1 = load_megatron_weights(ckpt_path1)
|
|
@@ -29,7 +29,7 @@ def compare(bench_zip_path, cmp_zip_path, output_path, framework):
|
|
|
29
29
|
def _config_checking_parser(parser):
|
|
30
30
|
parser.add_argument('-d', '--dump', nargs='*', help='Collect the train config into a zip file')
|
|
31
31
|
parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or checkpoints')
|
|
32
|
-
parser.add_argument('-o', '--output', help='output path, default is
|
|
32
|
+
parser.add_argument('-o', '--output', help='output path, default is ./config_check_result')
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def _run_config_checking_command(args):
|
|
@@ -43,8 +43,7 @@ class ConfigChecker:
|
|
|
43
43
|
|
|
44
44
|
@staticmethod
|
|
45
45
|
def compare(bench_zip_path, cmp_zip_path, output_path, fmk=Const.PT_FRAMEWORK):
|
|
46
|
-
|
|
47
|
-
shutil.rmtree(output_path)
|
|
46
|
+
create_directory(output_path)
|
|
48
47
|
bench_dir = os.path.join(output_path, "bench")
|
|
49
48
|
cmp_dir = os.path.join(output_path, "cmp")
|
|
50
49
|
extract_zip(bench_zip_path, bench_dir)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import atexit
|
|
17
17
|
import os
|
|
18
|
+
import threading
|
|
18
19
|
import traceback
|
|
19
20
|
|
|
20
21
|
from msprobe.core.data_dump.scope import ScopeFactory
|
|
@@ -255,7 +256,9 @@ class DataCollector:
|
|
|
255
256
|
else:
|
|
256
257
|
if self.config.level == Const.LEVEL_MIX and \
|
|
257
258
|
not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
|
|
258
|
-
self.data_writer.update_construct(
|
|
259
|
+
self.data_writer.update_construct(
|
|
260
|
+
{name: self.module_processor.api_parent_node.get(threading.get_ident())}
|
|
261
|
+
)
|
|
259
262
|
|
|
260
263
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
261
264
|
|
|
@@ -28,6 +28,7 @@ from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
|
28
28
|
from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
|
|
29
29
|
from msprobe.mindspore.common.log import logger
|
|
30
30
|
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
31
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
31
32
|
|
|
32
33
|
has_adump = True
|
|
33
34
|
try:
|
|
@@ -35,9 +36,15 @@ try:
|
|
|
35
36
|
except ImportError:
|
|
36
37
|
has_adump = False
|
|
37
38
|
|
|
39
|
+
if is_mindtorch():
|
|
40
|
+
from torch import distributed as dist
|
|
41
|
+
|
|
38
42
|
|
|
39
43
|
class MindsporeDataProcessor(BaseDataProcessor):
|
|
40
|
-
|
|
44
|
+
if is_mindtorch():
|
|
45
|
+
mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp, dist.ProcessGroup])
|
|
46
|
+
else:
|
|
47
|
+
mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
|
|
41
48
|
|
|
42
49
|
def __init__(self, config, data_writer):
|
|
43
50
|
super().__init__(config, data_writer)
|
|
@@ -114,6 +121,19 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
114
121
|
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
115
122
|
return f"{group_ranks_hash:08x}"
|
|
116
123
|
|
|
124
|
+
@staticmethod
|
|
125
|
+
def _analyze_process_group(arg):
|
|
126
|
+
group_info = {"type": "mindspore.ProcessGroup"}
|
|
127
|
+
try:
|
|
128
|
+
group_ranks = dist.get_process_group_ranks(arg)
|
|
129
|
+
group_info.update({"group_ranks": group_ranks})
|
|
130
|
+
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
131
|
+
group_id = f"{group_ranks_hash:08x}"
|
|
132
|
+
group_info.update({"group_id": group_id})
|
|
133
|
+
except Exception as e:
|
|
134
|
+
logger.warning(f"Failed to get process group ranks info with error info: {e}.")
|
|
135
|
+
return group_info
|
|
136
|
+
|
|
117
137
|
@classmethod
|
|
118
138
|
def get_special_types(cls):
|
|
119
139
|
return super().get_special_types() + cls.mindspore_special_type
|
|
@@ -149,6 +169,8 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
149
169
|
(np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
|
|
150
170
|
(distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str))
|
|
151
171
|
]
|
|
172
|
+
if is_mindtorch():
|
|
173
|
+
type_analyzer.append((dist.ProcessGroup, self._analyze_process_group))
|
|
152
174
|
for type_key, analyze_fn in type_analyzer:
|
|
153
175
|
if isinstance(element, type_key):
|
|
154
176
|
return analyze_fn(element)
|
|
@@ -30,7 +30,7 @@ from msprobe.core.common.utils import convert_tuple
|
|
|
30
30
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
31
31
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
32
32
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
33
|
-
from msprobe.pytorch.common.utils import
|
|
33
|
+
from msprobe.pytorch.common.utils import save_pt
|
|
34
34
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
35
35
|
|
|
36
36
|
is_gpu = False
|
|
@@ -181,7 +181,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
181
181
|
|
|
182
182
|
@staticmethod
|
|
183
183
|
def _analyze_torch_size(arg):
|
|
184
|
-
return {"type": "torch.Size", "value": list(arg)}
|
|
184
|
+
return {"type": "torch.Size", "value": [int(x) for x in list(arg)]}
|
|
185
185
|
|
|
186
186
|
@staticmethod
|
|
187
187
|
def _analyze_memory_format(arg):
|
|
@@ -210,18 +210,6 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
210
210
|
logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
|
|
211
211
|
return {"type": "torch.distributed.ReduceOp", "value": op_type}
|
|
212
212
|
|
|
213
|
-
@staticmethod
|
|
214
|
-
def _cast_to_float_if_fp8(tensor):
|
|
215
|
-
dtype = str(tensor.dtype)
|
|
216
|
-
if is_float8_tensor(tensor):
|
|
217
|
-
dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
|
|
218
|
-
logger.debug(
|
|
219
|
-
f"The {dtype} tensor analyzing/saving is unsupported in dump function."
|
|
220
|
-
f"Casting to float for processing."
|
|
221
|
-
)
|
|
222
|
-
tensor = tensor.float()
|
|
223
|
-
return tensor, dtype
|
|
224
|
-
|
|
225
213
|
@classmethod
|
|
226
214
|
def get_special_types(cls):
|
|
227
215
|
return super().get_special_types() + cls.pytorch_special_type
|
|
@@ -268,11 +256,10 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
268
256
|
return p2pop_info
|
|
269
257
|
|
|
270
258
|
def _analyze_tensor(self, tensor, suffix):
|
|
271
|
-
tensor, dtype = self._cast_to_float_if_fp8(tensor)
|
|
272
259
|
tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
|
|
273
260
|
tensor_json = {}
|
|
274
261
|
tensor_json.update({'type': 'torch.Tensor'})
|
|
275
|
-
tensor_json.update({'dtype': dtype})
|
|
262
|
+
tensor_json.update({'dtype': str(tensor.dtype)})
|
|
276
263
|
tensor_json.update({"shape": tensor.shape})
|
|
277
264
|
|
|
278
265
|
stat_values = [
|
|
@@ -295,7 +282,6 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
295
282
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
296
283
|
single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
297
284
|
single_arg.update({"data_name": dump_data_name})
|
|
298
|
-
tensor, _ = self._cast_to_float_if_fp8(tensor)
|
|
299
285
|
if self.config.async_dump:
|
|
300
286
|
self._async_dump_cache[file_path] = tensor.clone().detach()
|
|
301
287
|
else:
|
|
@@ -396,7 +382,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
396
382
|
self._analyze_maybe_overflow_flag()
|
|
397
383
|
if self.has_overflow:
|
|
398
384
|
for file_path, tensor in self.cached_tensors_and_file_paths.items():
|
|
399
|
-
tensor, _ = self._cast_to_float_if_fp8(tensor)
|
|
400
385
|
save_pt(tensor.clone().contiguous().detach(), file_path)
|
|
401
386
|
self.real_overflow_nums += 1
|
|
402
387
|
if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
|
|
@@ -588,11 +573,6 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
588
573
|
)
|
|
589
574
|
def clone_and_detach_tensor(self, input_params):
|
|
590
575
|
if isinstance(input_params, torch.Tensor):
|
|
591
|
-
if is_float8_tensor(input_params):
|
|
592
|
-
raise MsprobeException(
|
|
593
|
-
MsprobeException.UNSUPPORTED_TYPE_ERROR,
|
|
594
|
-
f"L2 backward dump does not support float8 type."
|
|
595
|
-
)
|
|
596
576
|
if input_params.requires_grad:
|
|
597
577
|
return input_params.clone().detach().requires_grad_()
|
|
598
578
|
return input_params.clone()
|
|
@@ -607,8 +587,6 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
607
587
|
|
|
608
588
|
def analyze_single_element(self, element, suffix_stack):
|
|
609
589
|
if isinstance(element, torch.Tensor):
|
|
610
|
-
if is_float8_tensor(element):
|
|
611
|
-
return {}
|
|
612
590
|
if not self.is_found_output_tensor:
|
|
613
591
|
if element.requires_grad:
|
|
614
592
|
self.forward_output_tensor = element
|