mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/monitor/utils.py
CHANGED

@@ -25,7 +25,7 @@ import torch
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.utils import is_int
-from msprobe.core.common.file_utils import check_file_or_directory_path
+from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod


 device = "cpu"

@@ -105,6 +105,15 @@ def validate_ops(ops):
     return valid_ops


+def validate_ndigits(ndigits):
+    if not ndigits:
+        return
+    if not is_int(ndigits) or ndigits <= 0:
+        raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
+    if ndigits > MonitorConst.MAX_NDIGITS:
+        raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
+
+
 def validate_ranks(ranks):
     if not isinstance(ranks, list):
         raise TypeError("module_ranks should be a list")

@@ -206,9 +215,17 @@ def validate_step_count_per_record(step_count_per_record):
         raise ValueError("step_count_per_record must smaller than 1e6")


+def validate_dynamic_on(dynamic_on):
+    if not isinstance(dynamic_on, bool):
+        raise TypeError('dynamic_on should be a bool')
+
+
 def validate_config(config):
     config['ops'] = validate_ops(config.get('ops', []))

+    ndigits = config.get('ndigits')
+    validate_ndigits(ndigits)
+
     eps = config.get('eps', 1e-8)
     if not isinstance(eps, float):
         raise TypeError("eps should be a float")

@@ -246,9 +263,20 @@ def validate_config(config):
     step_count_per_record = config.get('step_count_per_record', 1)
     validate_step_count_per_record(step_count_per_record)

+    config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
+                                            MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
+    config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
+                                               MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
+                                               MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
+    config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
+                                               MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
+
     squash_name = config.get('squash_name', True)
     validate_squash_name(squash_name)

+    dynamic_on = config.get('dynamic_on', False)
+    validate_dynamic_on(dynamic_on)
+
     if not targets:
         if xy_distribution:
             config["all_xy"] = True

@@ -257,6 +285,8 @@ def validate_config(config):

 def time_str2time_digit(time_str):
     time_format = '%b%d_%H-%M-%S'
+    if not isinstance(time_str, str):
+        raise TypeError(f"time_str:{time_str} should be a str")
     try:
         time_digit = datetime.strptime(time_str, time_format)
     except Exception as e:

@@ -284,3 +314,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
         if start_ok and end_ok:
             result[rank] = os.path.join(monitor_path, dirname)
     return result
+
+
+def chmod_tensorboard_dir(path):
+    """
+    format配置为tensorboard时,需要补充文件权限设置
+    """
+    try:
+        recursive_chmod(path)
+    except Exception as e:
+        logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
+
+
+def validate_set_monitor(grad_acc_steps, start_iteration):
+    """
+    validate parameters of set_monitor.
+    """
+    grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
+                                      MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
+
+    start_iteration = validate_int_arg(start_iteration, "start_iteration",
+                                       MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
+    return grad_acc_steps, start_iteration
+
+
+def validate_int_arg(value, name, minimum, default_value):
+    """Validate int args, if any exception occurs, use the default value."""
+    if value is None:
+        return default_value
+    try:
+        if not is_int(value):
+            raise TypeError(f"{name} must be int")
+        if value < minimum:
+            raise ValueError(f"{name} must greater than {minimum}")
+    except Exception as e:
+        value = default_value
+        logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
+    return value
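The new validators above either raise on malformed monitor-config values or, for the step-control arguments, fall back to defaults instead of failing. Below is a minimal, self-contained sketch of that fallback behaviour; the stand-in `is_int` and `print` replace msprobe's own helpers and logger, and the default values passed in are placeholders, since the `MonitorConst` defaults are not shown in this diff.

```python
# Sketch only: mirrors validate_int_arg() from the hunk above with stand-in pieces.
def is_int(value):
    # stand-in for msprobe.core.common.utils.is_int
    return isinstance(value, int) and not isinstance(value, bool)


def validate_int_arg(value, name, minimum, default_value):
    """Validate int args, if any exception occurs, use the default value."""
    if value is None:
        return default_value
    try:
        if not is_int(value):
            raise TypeError(f"{name} must be int")
        if value < minimum:
            raise ValueError(f"{name} must greater than {minimum}")
    except Exception as e:
        value = default_value
        print(f"Validate {name} failed, {e}, replaced with default value {value}.")
    return value


print(validate_int_arg(5, "start_step", 0, 0))      # 5
print(validate_int_arg("5", "start_step", 0, 0))    # falls back to 0 with a warning
print(validate_int_arg(-1, "step_interval", 1, 1))  # falls back to 1 with a warning
```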
msprobe/pytorch/online_dispatch/compare.py
CHANGED

@@ -125,8 +125,6 @@ class Saver:

     def write_summary_csv(self, test_result):
         test_rows = []
-        if self.stack_info:
-            test_rows[0].append(self.COLUMN_STACK_INFO)

         check_op_str_pattern_valid(test_result.api_name)
         df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
msprobe/pytorch/online_dispatch/dispatch.py
CHANGED

@@ -16,6 +16,7 @@
 import json
 import os
 import time
+import multiprocessing
 from multiprocessing import Pool

 import torch

@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
             return
         if dump_path is None:
             logger.error("Please set dump_path when dump_mode is config!")
+            raise DispatchException("Please set dump_path when dump_mode is config!")
         check_file_or_directory_path(dump_path, True)

         self.device_id = torch_npu._C._npu_getDevice()

@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
         self.get_ops(yaml_path)

         self.lock = None
+        max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
+        if process_num > max_process_num:
+            logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
+            raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
+                                    f'but got {process_num}!')
         if process_num > 0:
             self.pool = Pool(process_num)
         if debug:

@@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode):
             if len(json_line_data) == 0:
                 break
             msg = json.loads(json_line_data)
+            if len(msg) < 2:
+                raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
             self.all_summary[msg[0]] = msg[1]
         fp_handle.close()
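A worked example of the new process ceiling in `PtdbgDispatch.__init__`. The value of `Const.CPU_QUARTER` is not shown in this diff; 4 is assumed here purely for illustration, which would cap `process_num` at roughly a quarter of the available CPUs.

```python
# Illustrative sketch of the process_num ceiling added above.
import multiprocessing

CPU_QUARTER = 4  # assumed value, stand-in for Const.CPU_QUARTER

max_process_num = max(int((multiprocessing.cpu_count() + 1) // CPU_QUARTER), 1)
print(max_process_num)  # e.g. (32 + 1) // 4 = 8 on a 32-core host

process_num = 16
if process_num > max_process_num:
    # the real code logs via logger.error and raises DispatchException
    print(f"process_num should be less than or equal to {max_process_num}, "
          f"but got {process_num}!")
```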
msprobe/pytorch/online_dispatch/dump_compare.py
CHANGED

@@ -19,6 +19,8 @@ import os
 from datetime import datetime, timezone

 import torch
+from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
 from msprobe.pytorch.common.log import logger

@@ -91,6 +93,7 @@ def support_basic_type(data):
     return False


+@recursion_depth_decorator("dump_data")
 def dump_data(data, prefix, dump_path):
     if isinstance(data, (tuple, list)) and data:
         for i, item in enumerate(data):
msprobe/pytorch/online_dispatch/utils.py
CHANGED

@@ -27,8 +27,10 @@ else:
     pta_cpu_device = torch.device("cpu")

 from msprobe.core.common.const import CompareConst
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger

+
 cpu_device = torch._C.device("cpu")
 COLOR_RED = '\033[31m'
 COLOR_GREEN = '\033[32m'

@@ -85,6 +87,7 @@ def get_callstack():
     return callstack


+@recursion_depth_decorator("data_to_cpu")
 def data_to_cpu(data, deep, data_cpu):
     global cpu_device
     list_cpu = []
msprobe/pytorch/parse_tool/lib/interactive_cli.py
CHANGED

@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):

     @catch_exception
     def default(self, line=""):
-        self.
-        return False
-
-    @catch_exception
-    def do_run(self, line=""):
-        self.util.execute_command(line)
+        self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")

     @catch_exception
     def do_vc(self, line=""):
msprobe/pytorch/parse_tool/lib/utils.py
CHANGED

@@ -119,6 +119,7 @@ class Util:

     @staticmethod
     def deal_with_dir_or_file_inconsistency(output_path):
+        logger.warning(f"Trying to delete {output_path}")
         remove_path(output_path)
         raise ParseException("Inconsistent directory structure or file.")

@@ -264,7 +265,7 @@ class Util:
             match = re_pattern.match(name)
             if not match:
                 continue
-            if extern_pattern != '' and re_pattern.match(extern_pattern) and not
+            if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
                 continue
             file_list[name] = gen_info_func(name, match, file["root"])
         return file_list
msprobe/pytorch/pt_config.py
CHANGED

@@ -16,9 +16,10 @@
 import os
 import re

-from msprobe.core.common.const import Const
+from msprobe.core.common.const import Const, FileCheckConst
 from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
+from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid, \
+    FileChecker
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import is_int
 from msprobe.core.common_config import BaseConfig, CommonConfig

@@ -66,6 +67,7 @@ class TensorConfig(BaseConfig):
             check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
             check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
             check_crt_valid(os.path.join(self.tls_path, "client.crt"))
+            check_crt_valid(os.path.join(self.tls_path, "client.key"), True)

         if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
             raise Exception(f"host: {self.host} is invalid.")

@@ -95,6 +97,8 @@ class OverflowCheckConfig(BaseConfig):
     def check_overflow_config(self):
         if self.overflow_nums is not None and not is_int(self.overflow_nums):
             raise Exception("overflow_num is invalid")
+        if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
+            raise Exception("overflow_nums should be -1 or positive integer")
         if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
             raise Exception("check_mode is invalid")

@@ -148,7 +152,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
         ):
             msg = (
-                f"You
+                f"You need to and can only set fuzz_device as {DeviceType.CPU} "
                 f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
             )
             logger.error_log_with_exp(

@@ -271,13 +275,13 @@ class RunUTConfig(BaseConfig):

     @classmethod
     def check_nfs_path_config(cls, nfs_path):
-        if nfs_path
-
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()

     @classmethod
     def check_tls_path_config(cls, tls_path):
-        if tls_path
-
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()

     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
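A minimal sketch of the tightened `overflow_nums` rule added above: besides being an int, the value must now be -1 or a positive integer. `isinstance` stands in for msprobe's `is_int` here, so this is an illustration of the rule only, not the config class itself.

```python
def check_overflow_nums(overflow_nums):
    if overflow_nums is None:
        return  # unset remains allowed
    if not isinstance(overflow_nums, int):  # stand-in for is_int
        raise Exception("overflow_num is invalid")
    if overflow_nums != -1 and overflow_nums <= 0:
        raise Exception("overflow_nums should be -1 or positive integer")


check_overflow_nums(-1)    # accepted
check_overflow_nums(3)     # accepted
# check_overflow_nums(0)   # rejected: should be -1 or positive integer
```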
msprobe/pytorch/service.py
CHANGED

@@ -30,7 +30,7 @@ from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.hook_module.
+from msprobe.pytorch.hook_module.api_register import get_api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook

@@ -50,6 +50,8 @@ class Service:
         self.switch = False
         self.inner_switch = False
         self.current_iter = 0
+        self.loop = 0
+        self.init_step = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None

@@ -58,6 +60,7 @@ class Service:
         self.params_grad_info = {}
         self.hook_handle_dict = {}
         # 提前注册,确保注册尽可能多的API hook
+        self.api_register = get_api_register()
         self.register_api_hook()
         self.init_for_debug_level()

@@ -246,6 +249,8 @@ class Service:
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

     def start(self, model):
+        self.current_iter = self.loop + self.init_step
+        self.data_collector.update_iter(self.current_iter)
         if self.config.level == Const.LEVEL_DEBUG:
             return
         if self.need_stop_service():

@@ -304,8 +309,7 @@ class Service:
         if self.config.task == Const.TENSOR:
             self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
-        self.
-        self.data_collector.update_iter(self.current_iter)
+        self.loop += 1
         self.reset_status()

     def need_stop_service(self):

@@ -370,11 +374,10 @@ class Service:
     def register_api_hook(self):
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
             logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
-            api_register.initialize_hook(
-                functools.partial(self.build_hook, BaseScope.Module_Type_API)
-                self.config.online_run_ut
+            self.api_register.initialize_hook(
+                functools.partial(self.build_hook, BaseScope.Module_Type_API)
             )
-            api_register.
+            self.api_register.register_all_api()

     def register_module_hook(self):
         if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:

@@ -409,7 +412,7 @@ class Service:
         if self.config.nfs_path:
             self.attl.upload("end")
         elif self.attl.socket_manager is not None:
-            logger.info(f"pid: {os.getpid()} finished, start
+            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
             self.attl.socket_manager.send_stop_signal()

     def reset_status(self):
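A sketch of the new step bookkeeping in `Service`: `start()` now derives the dump iteration from a loop counter plus `init_step`, and the stop path advances the loop. Collaborators such as `data_collector` are reduced to comments; this illustrates only the counter logic visible in the hunks above.

```python
class ServiceStepSketch:
    def __init__(self, init_step=0):
        self.loop = 0
        self.init_step = init_step
        self.current_iter = 0

    def start(self):
        self.current_iter = self.loop + self.init_step
        # real code: self.data_collector.update_iter(self.current_iter)

    def stop(self):
        # real code writes dump data first, then advances the loop
        self.loop += 1


svc = ServiceStepSketch(init_step=100)
for _ in range(3):
    svc.start()
    svc.stop()
print(svc.current_iter)  # 102: iterations 100, 101 and 102 were processed
```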
msprobe/visualization/builder/graph_builder.py
CHANGED

@@ -16,19 +16,19 @@
 import re

 from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_json
+from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.visualization.builder.msprobe_adapter import get_input_output
 from msprobe.visualization.builder.msprobe_adapter import op_patterns
 from msprobe.visualization.graph.graph import Graph
 from msprobe.visualization.graph.node_op import NodeOp
-from msprobe.visualization.utils import
+from msprobe.visualization.utils import GraphConst


 class GraphBuilder:
     backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
     forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
-    # 匹配以大写字母开头,后接任意字母,并以Template(
-    template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
+    # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
+    template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')

     @staticmethod
     def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):

@@ -51,6 +51,7 @@ class GraphBuilder:
         graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
         GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
         GraphBuilder._collect_apis_between_modules(graph)
+        GraphBuilder._add_parameters_grad(graph, data_dict)
         return graph

     @staticmethod

@@ -73,7 +74,7 @@ class GraphBuilder:
         if config.task:
             result[GraphConst.JSON_TASK_KEY] = config.task
             result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
-
+        save_json(filename, result, indent=4)

     @staticmethod
     def _simplify_stack(stack_dict):

@@ -235,6 +236,44 @@ class GraphBuilder:

         graph.root.subnodes = output

+    @staticmethod
+    def _add_parameters_grad(graph, data_dict):
+        """
+        将parameters_grad信息添加到graph中,
+        对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点
+
+        例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2
+        则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点
+        """
+        prefixes = []
+        suffix = Const.SEP + Const.PARAMS_GRAD
+        for node_id in data_dict.keys():
+            if node_id not in graph.node_map and node_id.endswith(suffix):
+                prefixes.append(node_id.replace(suffix, ''))
+
+        max_info = {prefix: 0 for prefix in prefixes}
+
+        for key in graph.node_map.keys():
+            for prefix in prefixes:
+                # 构建正则表达式,匹配以 "backward.数字" 结尾的键
+                pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$')
+                match = pattern.match(key)
+                if match:
+                    num = int(match.group(1))
+                    if num > max_info[prefix]:
+                        max_info[prefix] = num
+
+        for prefix, num in max_info.items():
+            node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
+            node = graph.get_node(node_id)
+            if node:
+                parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
+                # 添加输入输出数据
+                node_data = data_dict.get(parameters_grad_node_id, {})
+                input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
+                # 更新数据
+                graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
+

 class GraphExportConfig:
     def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
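A standalone illustration of the matching step inside `_add_parameters_grad` above: for each module prefix, find the highest `.backward.N` counter among the graph's node ids, which is where the `parameters_grad` node gets attached. The node names follow the example in the method's docstring; the list and prefixes are made up for the demonstration.

```python
import re

node_ids = [
    "Module.a.backward.0",
    "Module.a.backward.1",
    "Module.a.backward.2",
    "Module.b.backward.0",
]
prefixes = ["Module.a", "Module.b"]

max_info = {prefix: 0 for prefix in prefixes}
for key in node_ids:
    for prefix in prefixes:
        pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$')
        match = pattern.match(key)
        if match:
            num = int(match.group(1))
            if num > max_info[prefix]:
                max_info[prefix] = num

print(max_info)  # {'Module.a': 2, 'Module.b': 0}
# Module.a.parameters_grad is then added under Module.a.backward.2
```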
msprobe/visualization/builder/msprobe_adapter.py
CHANGED

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-import math
 from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
 from msprobe.core.common.utils import set_dump_path, get_dump_mode
 from msprobe.visualization.utils import GraphConst
msprobe/visualization/compare/graph_comparator.py
CHANGED

@@ -17,12 +17,14 @@ import re
 from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
 from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
 from msprobe.visualization.graph.graph import Graph, NodeOp
-from msprobe.visualization.graph.node_colors import NodeColors
 from msprobe.visualization.compare.mode_adapter import ModeAdapter
 from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator


 class GraphComparator:
+    MAX_DEPTH = 1000
+
     def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
         self.graph_n = graphs[0]
         self.graph_b = graphs[1]

@@ -41,7 +43,7 @@ class GraphComparator:
         else:
             self._compare_nodes(self.graph_n.root)
         self._postcompare()
-
+
     def add_compare_result_to_node(self, node, compare_result_list):
         """
         将比对结果添加到节点的输入输出数据中

@@ -66,43 +68,8 @@ class GraphComparator:
             self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
         node.data[GraphConst.JSON_INDEX_KEY] = precision_index
         node.data.update(other_dict)
-
-    def _parse_param(self, dump_path_param, output_path):
-        self.dump_path_param = dump_path_param
-        self.output_path = output_path
-        compare_mode = get_compare_mode(self.dump_path_param)
-        self.ma = ModeAdapter(compare_mode)
-        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
-        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
-        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
-
-    def _postcompare(self):
-        self._handle_api_collection_index()
-        if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
-            return
-        df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
-        df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
-        compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
-        for node in self.ma.compare_nodes:
-            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
-            node.data[GraphConst.JSON_INDEX_KEY] = precision_index
-
-    def _handle_api_collection_index(self):
-        """
-        api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
-        md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
-        """
-        for node in self.graph_n.root.subnodes:
-            if node.op == NodeOp.api_collection:
-                precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
-                    else GraphConst.MIN_INDEX_KEY
-                for api in node.subnodes:
-                    precision_index = min(precision_index,
-                                          api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
-                        if self.ma.compare_mode == GraphConst.MD5_COMPARE \
-                        else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
-                node.data[GraphConst.JSON_INDEX_KEY] = precision_index

+    @recursion_depth_decorator('GraphComparator._compare_nodes', max_depth=MAX_DEPTH)
     def _compare_nodes(self, node_n):
         """
         递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比

@@ -126,6 +93,7 @@ class GraphComparator:
         for subnode in node_n.subnodes:
             self._compare_nodes(subnode)

+    @recursion_depth_decorator('GraphComparator._compare_nodes_fuzzy', max_depth=MAX_DEPTH)
     def _compare_nodes_fuzzy(self, node_n):
         if node_n.op != NodeOp.function_api:
             # 模块经过模糊匹配

@@ -146,6 +114,42 @@ class GraphComparator:
         for sub_node in node_n.subnodes:
             self._compare_nodes_fuzzy(sub_node)

+    def _parse_param(self, dump_path_param, output_path):
+        self.dump_path_param = dump_path_param
+        self.output_path = output_path
+        compare_mode = get_compare_mode(self.dump_path_param)
+        self.ma = ModeAdapter(compare_mode)
+        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
+        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
+        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
+
+    def _postcompare(self):
+        self._handle_api_collection_index()
+        if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
+            return
+        df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
+        df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
+        compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
+        for node in self.ma.compare_nodes:
+            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
+            node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+
+    def _handle_api_collection_index(self):
+        """
+        api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
+        md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
+        """
+        for node in self.graph_n.root.subnodes:
+            if node.op == NodeOp.api_collection:
+                precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
+                    else GraphConst.MIN_INDEX_KEY
+                for api in node.subnodes:
+                    precision_index = min(precision_index,
+                                          api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
+                        if self.ma.compare_mode == GraphConst.MD5_COMPARE \
+                        else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
+                node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+
     def _get_and_add_result(self, node_n, node_b):
         compare_result_list = compare_node([node_n.id, node_b.id],
                                            [self.data_n_dict, self.data_b_dict],
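`_compare_nodes` and `_compare_nodes_fuzzy` above are now guarded by `recursion_depth_decorator` from `msprobe.core.common.decorator`, a file added in 1.3.0 whose body is not shown in this diff. The following is a generic sketch of how such a depth guard can be built, written only to illustrate the idea; it is not the actual msprobe implementation, and the name `recursion_depth_guard` is invented here.

```python
import functools


def recursion_depth_guard(label, max_depth=1000):
    def decorator(func):
        state = {"depth": 0}  # per-function depth counter

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            state["depth"] += 1
            try:
                if state["depth"] > max_depth:
                    raise RecursionError(f"{label} exceeded max depth {max_depth}")
                return func(*args, **kwargs)
            finally:
                state["depth"] -= 1
        return wrapper
    return decorator


@recursion_depth_guard("countdown", max_depth=10)
def countdown(n):
    return n if n == 0 else countdown(n - 1)


print(countdown(5))  # 0
# countdown(50) would raise RecursionError instead of recursing unbounded
```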
msprobe/visualization/compare/mode_adapter.py
CHANGED

@@ -14,7 +14,6 @@
 # limitations under the License.

 import json
-import math
 from msprobe.core.common.const import CompareConst, Const
 from msprobe.visualization.utils import ToolTip, GraphConst, str2float

@@ -157,24 +156,6 @@ class ModeAdapter:
             return
         self.csv_data.extend(compare_result_list)

-    def add_error_key(self, node_data):
-        """
-        根据不同的模式进行提供不同错误信息
-        """
-        for key, value in node_data.items():
-            if not isinstance(value, dict):
-                continue
-            if self.compare_mode == GraphConst.SUMMARY_COMPARE:
-                message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
-                           CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
-            elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
-                message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
-            else:
-                # 输出件优化
-                message = []
-            value[GraphConst.ERROR_KEY] = message
-            node_data[key] = value
-
     def get_tool_tip(self):
         """
         用于前端展示字段的具体含义
msprobe/visualization/graph/base_node.py
CHANGED

@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from msprobe.core.overflow_check.level import OverflowLevel
-from msprobe.visualization.graph.node_op import NodeOp
 from msprobe.visualization.utils import GraphConst
 from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
+from msprobe.core.common.log import logger


 class BaseNode:

@@ -114,7 +115,13 @@ class BaseNode:
         """
         ancestors = []
         current_node = self.upnode
+        seen_nodes = set()
         while current_node:
+            if current_node.id in seen_nodes:
+                logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
+                               f'current node is {current_node.id}.')
+                return []
+            seen_nodes.add(current_node.id)
             ancestors.append(current_node.id)
             current_node = current_node.upnode
         return list(reversed(ancestors))
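A tiny sketch of the cycle guard added to the ancestor lookup above: the upward walk records visited node ids so a malformed parent chain cannot loop forever. The `Node` class is a stand-in invented for the demonstration.

```python
class Node:
    def __init__(self, node_id, upnode=None):
        self.id = node_id
        self.upnode = upnode


def get_ancestors(node):
    ancestors, seen_nodes = [], set()
    current_node = node.upnode
    while current_node:
        if current_node.id in seen_nodes:
            print(f"cycle detected at {current_node.id}, returning []")
            return []
        seen_nodes.add(current_node.id)
        ancestors.append(current_node.id)
        current_node = current_node.upnode
    return list(reversed(ancestors))


root = Node("root")
mid = Node("mid", upnode=root)
leaf = Node("leaf", upnode=mid)
print(get_ancestors(leaf))  # ['root', 'mid']

root.upnode = mid  # introduce a cycle
print(get_ancestors(leaf))  # [] plus a warning
```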
msprobe/visualization/graph/distributed_analyzer.py
CHANGED

@@ -107,15 +107,6 @@ class DistributedAnalyzer:
             return None, None
         return group_ranks, group_id

-    @staticmethod
-    def _get_batch_group_info(node, rank):
-        for data in node.input_data.values():
-            group_id = data.get('group_id')
-            if group_id is not None:
-                return group_id
-        logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
-        return None
-
     def distributed_match(self):
         for rank, graph in self.graphs.items():
             nodes = graph.node_map

@@ -377,7 +368,7 @@ class DistributedAnalyzer:
                 target_api_name = self.config.get(api_name)[0]
                 target_rank = int(id_info[1].replace(Const.RANK, ''))
             except Exception as e:
-                logger.warning(f'Failed to
+                logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
                 continue
             target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
             if not target_node:
|