mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/pt_config.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +16,10 @@
 import os
 import re

-from msprobe.core.common.const import Const
+from msprobe.core.common.const import Const, FileCheckConst
 from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
+from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid, \
+    FileChecker
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import is_int
 from msprobe.core.common_config import BaseConfig, CommonConfig
@@ -66,6 +67,7 @@ class TensorConfig(BaseConfig):
             check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
             check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
             check_crt_valid(os.path.join(self.tls_path, "client.crt"))
+            check_crt_valid(os.path.join(self.tls_path, "client.key"), True)

         if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
             raise Exception(f"host: {self.host} is invalid.")
@@ -95,6 +97,8 @@ class OverflowCheckConfig(BaseConfig):
     def check_overflow_config(self):
         if self.overflow_nums is not None and not is_int(self.overflow_nums):
             raise Exception("overflow_num is invalid")
+        if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
+            raise Exception("overflow_nums should be -1 or positive integer")
         if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
             raise Exception("check_mode is invalid")

@@ -148,7 +152,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
         ):
             msg = (
-                f"You
+                f"You need to and can only set fuzz_device as {DeviceType.CPU} "
                 f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
             )
             logger.error_log_with_exp(
@@ -271,13 +275,13 @@ class RunUTConfig(BaseConfig):

     @classmethod
     def check_nfs_path_config(cls, nfs_path):
-        if nfs_path
-
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()

     @classmethod
     def check_tls_path_config(cls, tls_path):
-        if tls_path
-
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()

     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -303,28 +307,25 @@ class GradToolConfig(BaseConfig):
         check_bounds(self.bounds)


+class StructureConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+
+
+TaskDict = {
+    Const.TENSOR: TensorConfig,
+    Const.STATISTICS: StatisticsConfig,
+    Const.OVERFLOW_CHECK: OverflowCheckConfig,
+    Const.FREE_BENCHMARK: FreeBenchmarkCheckConfig,
+    Const.RUN_UT: RunUTConfig,
+    Const.GRAD_PROBE: GradToolConfig,
+    Const.STRUCTURE: StructureConfig
+}
+
+
 def parse_task_config(task, json_config):
-
-
-        config_dic = json_config.get(Const.TENSOR, default_dic)
-        return TensorConfig(config_dic)
-    elif task == Const.STATISTICS:
-        config_dic = json_config.get(Const.STATISTICS, default_dic)
-        return StatisticsConfig(config_dic)
-    elif task == Const.OVERFLOW_CHECK:
-        config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
-        return OverflowCheckConfig(config_dic)
-    elif task == Const.FREE_BENCHMARK:
-        config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
-        return FreeBenchmarkCheckConfig(config_dic)
-    elif task == Const.RUN_UT:
-        config_dic = json_config.get(Const.RUN_UT, default_dic)
-        return RunUTConfig(config_dic)
-    elif task == Const.GRAD_PROBE:
-        config_dic = json_config.get(Const.GRAD_PROBE, default_dic)
-        return GradToolConfig(config_dic)
-    else:
-        return StatisticsConfig(default_dic)
+    task_map = json_config.get(task, dict())
+    return TaskDict.get(task)(task_map)


 def parse_json_config(json_file_path, task):
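
Note on the parse_task_config change above: the per-task if/elif chain is replaced by a lookup table (TaskDict) keyed by task name, plus a new StructureConfig entry. A minimal, self-contained sketch of the same table-driven dispatch pattern, using hypothetical config classes rather than the msprobe ones:

class BaseConfig:
    def __init__(self, json_config):
        # hypothetical example field, only to show the section dict being consumed
        self.enabled = json_config.get("enabled", True)


class TensorConfig(BaseConfig):
    pass


class StatisticsConfig(BaseConfig):
    pass


# task name -> config class, mirroring the TaskDict idea from the diff
TASK_DICT = {
    "tensor": TensorConfig,
    "statistics": StatisticsConfig,
}


def parse_task_config(task, json_config):
    # look up the task's own section (empty dict if absent) and build the matching config
    task_map = json_config.get(task, dict())
    return TASK_DICT.get(task)(task_map)


if __name__ == "__main__":
    cfg = parse_task_config("tensor", {"tensor": {"enabled": False}})
    print(type(cfg).__name__, cfg.enabled)  # TensorConfig False

One behavioural difference visible in the diff: the old code fell back to StatisticsConfig for an unknown task, while the new lookup returns None for a task not in the table, so the subsequent call raises; callers are expected to pass a task that exists in the table.
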
msprobe/pytorch/service.py
CHANGED
@@ -15,22 +15,22 @@

 import functools
 import os
-from collections import namedtuple
+from collections import namedtuple, defaultdict

 import torch
 from msprobe.core.common.const import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import print_tools_ends_info
+from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.hook_module.
+from msprobe.pytorch.hook_module.api_register import get_api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook

@@ -50,19 +50,25 @@ class Service:
         self.switch = False
         self.inner_switch = False
         self.current_iter = 0
+        self.loop = 0
+        self.init_step = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
         self.should_stop_service = False
         self.attl = None
         self.params_grad_info = {}
+        self.hook_handle_dict = {}
         # 提前注册,确保注册尽可能多的API hook
+        self.api_register = get_api_register()
         self.register_api_hook()
+        self.init_for_debug_level()

     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
             if not self.should_execute_hook(module_type, module, True):
                 return args, kwargs
+            is_recompute = is_recomputation()

             self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
@@ -77,7 +83,13 @@ class Service:
                 return None, None
             if self.data_collector:
                 module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
-                self.data_collector.forward_input_data_collect(
+                self.data_collector.forward_input_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )

             self.inner_switch = False
             return args, kwargs
@@ -101,7 +113,12 @@ class Service:
             if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                 for param_name, param in params_dict.items():
                     if param.requires_grad:
-
+                        name = ori_name + Const.SEP + param_name
+                        old_handle = self.hook_handle_dict.get(name)
+                        if old_handle and hasattr(old_handle, "remove"):
+                            old_handle.remove()
+                        handle = param.register_hook(grad_hook(module, ori_name, param_name))
+                        self.hook_handle_dict[name] = handle

         def init_params_grad_info(module, params_dict):
             '''
@@ -125,6 +142,7 @@ class Service:
         def forward_hook(api_or_module_name, module, args, kwargs, output):
             if not self.should_execute_hook(module_type, module, True):
                 return None
+            is_recompute = is_recomputation()

             self.inner_switch = True
             if self.config.online_run_ut:
@@ -147,10 +165,15 @@ class Service:
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name[-1]
                 self.data_collector.update_api_or_module_name(api_or_module_name)
-                params_dict = {
-
+                params_dict = {}
+                if self.config.task != Const.STRUCTURE:
+                    params_dict = {
+                        key.split(Const.SEP)[-1]: value
+                        for key, value in module.named_parameters(recurse=False)
+                    }
+                setattr(module_input_output, Const.PARAMS, params_dict)
                 # 判断是否需要注册参数hook
-                if
+                if params_dict:
                     ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
                     grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                     # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
@@ -160,7 +183,8 @@ class Service:
                         api_or_module_name,
                         module,
                         pid,
-                        module_input_output
+                        module_input_output,
+                        is_recompute
                     )
                     init_params_grad_info(module, params_dict)
                 else:
@@ -169,7 +193,8 @@ class Service:
                         api_or_module_name,
                         module,
                         pid,
-                        module_input_output
+                        module_input_output,
+                        is_recompute
                     )

             if self.data_collector.if_return_forward_new_output():
@@ -185,6 +210,7 @@ class Service:
         def backward_hook(api_or_module_name, module, grad_input, grad_output):
             if not self.should_execute_hook(module_type, module, False):
                 return
+            is_recompute = is_recomputation()

             self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
@@ -198,7 +224,13 @@ class Service:
             if self.data_collector:
                 # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
                 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-                self.data_collector.backward_data_collect(
+                self.data_collector.backward_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
             self.inner_switch = False

         pid = os.getpid()
@@ -217,6 +249,10 @@ class Service:
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

     def start(self, model):
+        self.current_iter = self.loop + self.init_step
+        self.data_collector.update_iter(self.current_iter)
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.need_stop_service():
             return

@@ -231,6 +267,8 @@ class Service:
             if self.config.rank and self.current_rank not in self.config.rank:
                 return
             self.register_module_hook()
+            if self.config.level == Const.LEVEL_MIX:
+                register_optimizer_hook(self.data_collector)
             self.first_start = False
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
@@ -241,6 +279,8 @@ class Service:
         logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

     def stop(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         if self.config.step and self.current_iter not in self.config.step:
@@ -255,18 +295,21 @@ class Service:
             return
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()

     def step(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
-        self.
-        self.data_collector.update_iter(self.current_iter)
+        self.loop += 1
         self.reset_status()

     def need_stop_service(self):
@@ -319,26 +362,22 @@ class Service:
         else:
             dump_data_dir = None

-
-
-
-
-
-
-        )
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
         self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)

     def register_api_hook(self):
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
             logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
-            api_register.initialize_hook(
-                functools.partial(self.build_hook, BaseScope.Module_Type_API)
-                self.config.online_run_ut
+            self.api_register.initialize_hook(
+                functools.partial(self.build_hook, BaseScope.Module_Type_API)
             )
-            api_register.
-
-            if self.config.level == Const.LEVEL_MIX:
-                register_optimizer_hook(self.data_collector)
+            self.api_register.register_all_api()

     def register_module_hook(self):
         if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
@@ -373,13 +412,13 @@ class Service:
         if self.config.nfs_path:
             self.attl.upload("end")
         elif self.attl.socket_manager is not None:
-            logger.info(f"pid: {os.getpid()} finished, start
+            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
             self.attl.socket_manager.send_stop_signal()

     def reset_status(self):
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
-        self.data_collector.
+        self.data_collector.reset_status()
         self.params_grad_info.clear()

         if self.config.level == Const.LEVEL_L2:
@@ -389,3 +428,46 @@ class Service:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
             return
+
+    def init_for_debug_level(self):
+        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+            return
+        try:
+            self.current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            self.current_rank = None
+
+        # dir: dump_path -- rank{} -- debug.json
+        self.dump_iter_dir = self.config.dump_path
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        create_directory(dump_dir)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            create_directory(dump_data_dir)
+        else:
+            dump_data_dir = None
+
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
+
+        self.debug_variable_counter = defaultdict(int)
+
+    def save(self, variable, name, save_backward):
+        if self.config.level != Const.LEVEL_DEBUG:
+            return
+        count = self.debug_variable_counter[name]
+        self.debug_variable_counter[name] += 1
+
+        name_with_count = f"{name}.{count}"
+        grad_name_with_count = f"{name}_grad.{count}"
+
+        # forward save
+        self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+        # backward save
+        if save_backward:
+            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
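
Note on the new debug-level save() above: each saved name gets a per-name counter from a defaultdict, so repeated saves of the same variable are recorded as name.0, name.1, and so on, with a matching name_grad.N when save_backward is set. A standalone sketch of just that counting behaviour (not msprobe code; the collector calls are replaced by returned strings):

from collections import defaultdict


class DebugSaver:
    """Toy stand-in for Service at debug level; only models the naming counter."""

    def __init__(self):
        self.debug_variable_counter = defaultdict(int)

    def save(self, name, save_backward=False):
        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1
        records = [f"{name}.{count}"]               # forward record
        if save_backward:
            records.append(f"{name}_grad.{count}")  # matching grad record
        return records


saver = DebugSaver()
print(saver.save("loss"))                      # ['loss.0']
print(saver.save("loss", save_backward=True))  # ['loss.1', 'loss_grad.1']
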
msprobe/visualization/builder/graph_builder.py
CHANGED
@@ -16,18 +16,19 @@
 import re

 from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_json
+from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.visualization.builder.msprobe_adapter import get_input_output
 from msprobe.visualization.builder.msprobe_adapter import op_patterns
 from msprobe.visualization.graph.graph import Graph
 from msprobe.visualization.graph.node_op import NodeOp
-from msprobe.visualization.utils import
+from msprobe.visualization.utils import GraphConst


 class GraphBuilder:
     backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
-
-
+    forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
+    # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
+    template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')

     @staticmethod
     def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
@@ -50,6 +51,7 @@ class GraphBuilder:
         graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
         GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
         GraphBuilder._collect_apis_between_modules(graph)
+        GraphBuilder._add_parameters_grad(graph, data_dict)
         return graph

     @staticmethod
@@ -72,7 +74,7 @@ class GraphBuilder:
         if config.task:
             result[GraphConst.JSON_TASK_KEY] = config.task
             result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
-
+        save_json(filename, result, indent=4)

     @staticmethod
     def _simplify_stack(stack_dict):
@@ -113,12 +115,17 @@ class GraphBuilder:
        如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点
        """
        # 匹配以.backward.后跟一个或多个数字结尾的模式
-        backward_pattern
-
-        if re.search(backward_pattern, subnode_id) and not upnode_id:
-            forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
+        if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id:
+            forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id))
             if forward_upnode_id:
-                new_upnode_id =
+                new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id)
+                if new_upnode_id in construct_dict:
+                    return new_upnode_id
+        # 匹配以.backward结尾的节点
+        if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id:
+            forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD))
+            if forward_upnode_id:
+                new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD)
                 if new_upnode_id in construct_dict:
                     return new_upnode_id
         return upnode_id
@@ -148,6 +155,8 @@ class GraphBuilder:
         input_data, output_data = get_input_output(node_data, node.id)
         # 更新数据
         node.set_input_output(input_data, output_data)
+        if GraphConst.BATCH_P2P in name:
+            GraphBuilder._extract_batch_p2p_info(node, node_data)
         # 反向节点使用对应前向节点的堆栈信息
         # 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward
         if (not node_stack_info and
@@ -164,6 +173,24 @@ class GraphBuilder:
             node.add_upnode(upnode)
         return node

+    @staticmethod
+    def _is_valid_batch_p2p_output(param_list):
+        if not isinstance(param_list, list) or not param_list:
+            return False
+        if not isinstance(param_list[0], list) or not param_list[0]:
+            return False
+        return True
+
+    @staticmethod
+    def _extract_batch_p2p_info(node, node_data):
+        param_list = node_data.get(Const.OUTPUT, [])
+        # 数据格式:"output": [[{param1}, {param2}, ...]]
+        if GraphBuilder._is_valid_batch_p2p_output(param_list):
+            for param in param_list[0]:
+                info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
+                        GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
+                node.batch_p2p_info.append(info)
+
     @staticmethod
     def _collect_apis_between_modules(graph):
         """
@@ -209,6 +236,44 @@ class GraphBuilder:

         graph.root.subnodes = output

+    @staticmethod
+    def _add_parameters_grad(graph, data_dict):
+        """
+        将parameters_grad信息添加到graph中,
+        对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点
+
+        例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2
+        则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点
+        """
+        prefixes = []
+        suffix = Const.SEP + Const.PARAMS_GRAD
+        for node_id in data_dict.keys():
+            if node_id not in graph.node_map and node_id.endswith(suffix):
+                prefixes.append(node_id.replace(suffix, ''))
+
+        max_info = {prefix: 0 for prefix in prefixes}
+
+        for key in graph.node_map.keys():
+            for prefix in prefixes:
+                # 构建正则表达式,匹配以 "backward.数字" 结尾的键
+                pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$')
+                match = pattern.match(key)
+                if match:
+                    num = int(match.group(1))
+                    if num > max_info[prefix]:
+                        max_info[prefix] = num
+
+        for prefix, num in max_info.items():
+            node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
+            node = graph.get_node(node_id)
+            if node:
+                parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
+                # 添加输入输出数据
+                node_data = data_dict.get(parameters_grad_node_id, {})
+                input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
+                # 更新数据
+                graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
+

 class GraphExportConfig:
     def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
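
Note on _add_parameters_grad above: for every module prefix that has a parameters_grad entry, the code scans existing node ids for the pattern <prefix>.backward.<N> and attaches the parameters_grad node under the backward node with the largest N. A standalone sketch of that max-index selection (not msprobe code):

import re

node_ids = [
    "Module.a.backward.0",
    "Module.a.backward.1",
    "Module.a.backward.2",
    "Module.b.backward.0",
]
prefixes = ["Module.a", "Module.b"]

# highest backward index seen per prefix, defaulting to 0
max_info = {prefix: 0 for prefix in prefixes}
for key in node_ids:
    for prefix in prefixes:
        pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$')
        match = pattern.match(key)
        if match:
            max_info[prefix] = max(max_info[prefix], int(match.group(1)))

print(max_info)  # {'Module.a': 2, 'Module.b': 0}
# Module.a.parameters_grad would therefore be attached under Module.a.backward.2
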
msprobe/visualization/builder/msprobe_adapter.py
CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-import math
 from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
 from msprobe.core.common.utils import set_dump_path, get_dump_mode
 from msprobe.visualization.utils import GraphConst
@@ -23,7 +22,7 @@ from msprobe.core.compare.acc_compare import ModeConfig
 # 用于将节点名字解析成对应的NodeOp的规则
 op_patterns = [
     # NodeOp.module
-    r'^(Module.|Cell
+    r'^(Module.|Cell.|optimizer|clip_grad)',
     # NodeOp.function_api
     r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
 ]
@@ -57,8 +56,8 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
         from msprobe.pytorch.compare.pt_compare import PTComparator
         return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
     else:
-        from msprobe.mindspore.compare.ms_compare import MSComparator
-        ms_comparator = MSComparator(mode_config)
+        from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
+        ms_comparator = MSComparator(mode_config, MappingConfig())
         ms_comparator.cross_frame = is_cross_frame
         return ms_comparator.do_multi_process(dump_path_param, csv_path)

@@ -120,11 +119,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
     return True


-def format_node_data(data_dict):
+def format_node_data(data_dict, node_id=None):
     """
-
+    删除节点数据中不需要展示的字段
     """
     del_list = ['requires_grad', 'full_op_name']
+    if node_id and GraphConst.BATCH_P2P in node_id:
+        del_list.extend(['op', 'peer', 'tag', 'group_id'])
     for _, value in data_dict.items():
         if not isinstance(value, dict):
             continue