mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/service.py
CHANGED
|
@@ -1,22 +1,42 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import functools
|
|
2
17
|
import os
|
|
3
|
-
|
|
4
18
|
from collections import namedtuple
|
|
19
|
+
|
|
5
20
|
import torch
|
|
6
21
|
from msprobe.core.common.const import Const
|
|
7
22
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
8
23
|
from msprobe.core.common.file_utils import create_directory
|
|
24
|
+
from msprobe.core.common.utils import print_tools_ends_info
|
|
9
25
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
10
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
11
27
|
from msprobe.core.data_dump.scope import BaseScope
|
|
28
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
12
29
|
from msprobe.pytorch.common.log import logger
|
|
13
30
|
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
31
|
+
from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
14
32
|
from msprobe.pytorch.hook_module import remove_dropout
|
|
15
33
|
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
16
34
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
17
35
|
from msprobe.pytorch.module_processer import ModuleProcesser
|
|
18
|
-
|
|
36
|
+
|
|
19
37
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
38
|
+
if torch_version_above_or_equal_2:
|
|
39
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
20
40
|
|
|
21
41
|
HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
|
|
22
42
|
|
|
@@ -32,6 +52,7 @@ class Service:
|
|
|
32
52
|
self.first_start = True
|
|
33
53
|
self.current_rank = None
|
|
34
54
|
self.dump_iter_dir = None
|
|
55
|
+
self.should_stop_service = False
|
|
35
56
|
self.attl = None
|
|
36
57
|
|
|
37
58
|
@staticmethod
|
|
@@ -39,14 +60,29 @@ class Service:
|
|
|
39
60
|
logger.info_on_rank_0("Data needed ends here.")
|
|
40
61
|
api_register.api_originality()
|
|
41
62
|
|
|
63
|
+
@staticmethod
|
|
64
|
+
def is_registered_backward_hook(module):
|
|
65
|
+
if hasattr(module, '_backward_hooks') and \
|
|
66
|
+
len(module._backward_hooks) > 0 and \
|
|
67
|
+
module._is_full_backward_hook is False:
|
|
68
|
+
return True
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
def check_register_full_backward_hook(self, module):
|
|
72
|
+
if self.is_registered_backward_hook(module):
|
|
73
|
+
module._backward_hooks.clear()
|
|
74
|
+
module._is_full_backward_hook = None
|
|
75
|
+
logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
|
|
76
|
+
|
|
42
77
|
def build_hook(self, module_type, name):
|
|
43
78
|
def pre_hook(api_or_module_name, module, args, kwargs):
|
|
79
|
+
if not self.should_execute_hook():
|
|
80
|
+
return args, kwargs
|
|
81
|
+
|
|
44
82
|
if module_type == BaseScope.Module_Type_Module:
|
|
45
83
|
api_or_module_name = module.mindstudio_reserved_name
|
|
46
84
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
47
85
|
|
|
48
|
-
if not self.switch:
|
|
49
|
-
return args, kwargs
|
|
50
86
|
if self.config.online_run_ut:
|
|
51
87
|
return None, None
|
|
52
88
|
if self.data_collector:
|
|
@@ -55,13 +91,13 @@ class Service:
|
|
|
55
91
|
return args, kwargs
|
|
56
92
|
|
|
57
93
|
def forward_hook(api_or_module_name, module, args, kwargs, output):
|
|
94
|
+
if not self.should_execute_hook():
|
|
95
|
+
return None
|
|
96
|
+
|
|
58
97
|
if module_type == BaseScope.Module_Type_Module:
|
|
59
98
|
api_or_module_name = module.mindstudio_reserved_name
|
|
60
99
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
61
100
|
|
|
62
|
-
if not self.switch:
|
|
63
|
-
return None
|
|
64
|
-
|
|
65
101
|
if self.config.online_run_ut:
|
|
66
102
|
if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
|
|
67
103
|
return None
|
|
@@ -80,18 +116,14 @@ class Service:
|
|
|
80
116
|
return forward_hook(api_or_module_name, module, args, {}, output)
|
|
81
117
|
|
|
82
118
|
def backward_hook(api_or_module_name, module, grad_input, grad_output):
|
|
119
|
+
if not self.should_execute_hook():
|
|
120
|
+
return
|
|
121
|
+
|
|
83
122
|
if module_type == BaseScope.Module_Type_Module:
|
|
84
123
|
api_or_module_name = module.mindstudio_reserved_name
|
|
85
124
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
86
125
|
|
|
87
|
-
if not self.switch:
|
|
88
|
-
return
|
|
89
|
-
|
|
90
126
|
if self.config.online_run_ut:
|
|
91
|
-
if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
|
|
92
|
-
return
|
|
93
|
-
api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank)
|
|
94
|
-
self.attl_send(api_data)
|
|
95
127
|
return
|
|
96
128
|
|
|
97
129
|
if self.data_collector:
|
|
@@ -105,26 +137,15 @@ class Service:
|
|
|
105
137
|
pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
|
|
106
138
|
forward_hook_fn = functools.partial(forward_hook, forward_name_template)
|
|
107
139
|
backward_hook_fn = functools.partial(backward_hook, backward_name_template)
|
|
108
|
-
forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
|
|
140
|
+
forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
|
|
141
|
+
forward_name_template)
|
|
109
142
|
return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
|
|
110
143
|
|
|
111
|
-
def step(self):
|
|
112
|
-
self.current_iter += 1
|
|
113
|
-
self.data_collector.update_iter(self.current_iter)
|
|
114
|
-
|
|
115
|
-
ModuleProcesser.reset_module_stats()
|
|
116
|
-
HOOKModule.reset_module_stats()
|
|
117
|
-
|
|
118
144
|
def start(self, model, api_origin=False):
|
|
119
|
-
self.
|
|
120
|
-
if self.config.step and self.current_iter > max(self.config.step):
|
|
121
|
-
if self.config.online_run_ut:
|
|
122
|
-
# send stop signal if online_run_ut
|
|
123
|
-
self.attl_stop()
|
|
124
|
-
self.stop()
|
|
125
|
-
raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
|
|
126
|
-
if self.config.step and self.current_iter not in self.config.step:
|
|
145
|
+
if self.need_stop_service():
|
|
127
146
|
return
|
|
147
|
+
|
|
148
|
+
self.model = model
|
|
128
149
|
if self.first_start:
|
|
129
150
|
try:
|
|
130
151
|
self.current_rank = get_rank_if_initialized()
|
|
@@ -138,13 +159,17 @@ class Service:
|
|
|
138
159
|
self.first_start = False
|
|
139
160
|
if api_origin:
|
|
140
161
|
api_register.api_modularity()
|
|
162
|
+
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
163
|
+
run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
|
|
141
164
|
self.switch = True
|
|
142
165
|
logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
143
|
-
if
|
|
166
|
+
if not self.config.online_run_ut:
|
|
144
167
|
self.create_dirs()
|
|
145
168
|
logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
146
169
|
|
|
147
170
|
def stop(self):
|
|
171
|
+
if self.should_stop_service:
|
|
172
|
+
return
|
|
148
173
|
if self.config.level == "L2":
|
|
149
174
|
return
|
|
150
175
|
if self.config.step and self.current_iter not in self.config.step:
|
|
@@ -152,14 +177,60 @@ class Service:
|
|
|
152
177
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
153
178
|
return
|
|
154
179
|
self.switch = False
|
|
155
|
-
if self.config.online_run_ut:
|
|
180
|
+
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
181
|
+
run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
|
|
156
182
|
return
|
|
157
183
|
self.data_collector.write_json()
|
|
158
184
|
|
|
185
|
+
def step(self):
|
|
186
|
+
if self.should_stop_service:
|
|
187
|
+
return
|
|
188
|
+
self.current_iter += 1
|
|
189
|
+
self.data_collector.update_iter(self.current_iter)
|
|
190
|
+
|
|
191
|
+
ModuleProcesser.reset_module_stats()
|
|
192
|
+
HOOKModule.reset_module_stats()
|
|
193
|
+
self.data_collector.data_writer.reset_cache()
|
|
194
|
+
|
|
195
|
+
if self.config.level == Const.LEVEL_L2:
|
|
196
|
+
self.data_collector.data_processor.reset_status()
|
|
197
|
+
|
|
198
|
+
def need_stop_service(self):
|
|
199
|
+
if self.should_stop_service:
|
|
200
|
+
return True
|
|
201
|
+
end_service = self.config.step and self.current_iter > max(self.config.step) or \
|
|
202
|
+
self.data_collector and self.data_collector.data_processor.is_terminated
|
|
203
|
+
if end_service:
|
|
204
|
+
if self.config.online_run_ut:
|
|
205
|
+
# send stop signal if online_run_ut
|
|
206
|
+
self.attl_stop()
|
|
207
|
+
if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
|
|
208
|
+
api_register.api_originality()
|
|
209
|
+
self.switch = False
|
|
210
|
+
self.should_stop_service = True
|
|
211
|
+
print_tools_ends_info()
|
|
212
|
+
return True
|
|
213
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
214
|
+
return True
|
|
215
|
+
return False
|
|
216
|
+
|
|
217
|
+
def should_execute_hook(self):
|
|
218
|
+
if not self.switch:
|
|
219
|
+
return False
|
|
220
|
+
if self.data_collector and self.data_collector.data_processor.is_terminated:
|
|
221
|
+
return False
|
|
222
|
+
return True
|
|
223
|
+
|
|
159
224
|
def create_dirs(self):
|
|
160
225
|
create_directory(self.config.dump_path)
|
|
161
226
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
162
227
|
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
228
|
+
if self.config.level == Const.LEVEL_L2:
|
|
229
|
+
create_directory(self.dump_iter_dir)
|
|
230
|
+
kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
|
|
231
|
+
self.config.kernel_config_path = kernel_config_path
|
|
232
|
+
return
|
|
233
|
+
|
|
163
234
|
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
164
235
|
create_directory(dump_dir)
|
|
165
236
|
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
@@ -187,14 +258,16 @@ class Service:
|
|
|
187
258
|
prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
|
|
188
259
|
module.__class__.__name__ + Const.SEP
|
|
189
260
|
|
|
190
|
-
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2
|
|
191
|
-
|
|
261
|
+
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
|
|
262
|
+
BaseScope.Module_Type_Module, prefix)
|
|
192
263
|
if torch_version_above_or_equal_2:
|
|
193
264
|
module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
194
265
|
else:
|
|
266
|
+
self.check_register_full_backward_hook(module)
|
|
195
267
|
module.register_full_backward_hook(
|
|
196
268
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
197
269
|
module.register_forward_hook(forward_hook_torch_version_below_2)
|
|
270
|
+
self.check_register_full_backward_hook(module)
|
|
198
271
|
module.register_full_backward_hook(backward_hook)
|
|
199
272
|
|
|
200
273
|
module.register_forward_pre_hook(
|
|
@@ -204,11 +277,13 @@ class Service:
|
|
|
204
277
|
if torch_version_above_or_equal_2:
|
|
205
278
|
module.register_full_backward_pre_hook(
|
|
206
279
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
280
|
+
self.check_register_full_backward_hook(module)
|
|
207
281
|
module.register_full_backward_hook(
|
|
208
282
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
209
283
|
|
|
210
284
|
if self.config.level in ["mix", "L1", "L2"]:
|
|
211
|
-
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)
|
|
285
|
+
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
|
|
286
|
+
self.config.online_run_ut)
|
|
212
287
|
api_register.api_modularity()
|
|
213
288
|
|
|
214
289
|
if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from msprobe.visualization.graph.graph import Graph
|
|
18
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
19
|
+
from msprobe.visualization.utils import save_json_file, GraphConst
|
|
20
|
+
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
+
from msprobe.core.common.file_utils import load_json
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GraphBuilder:
    """Static helpers that turn dump files into a Graph and export graphs as JSON."""

    @staticmethod
    def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
        """
        Public entry point for building a graph.
        Args:
            construct_path: path of construct.json
            data_path: path of dump.json
            stack_path: path of stack.json
            model_name: model name, supplied by the caller
        Returns: Graph, the data structure that represents the model graph
        """
        construct_info = load_json(construct_path)
        dump_info = load_json(data_path)
        stack_info = load_json(stack_path)
        node_data = dump_info.get(GraphConst.DATA_KEY, {})
        dump_dir = dump_info.get('dump_data_dir', '')
        model_graph = Graph(model_name, data_path=dump_dir, dump_data=node_data)
        GraphBuilder._init_nodes(model_graph, construct_info, node_data, stack_info)
        GraphBuilder._collect_apis_between_modules(model_graph)
        return model_graph

    @staticmethod
    def to_json(filename, config):
        """
        Export the graph(s) held by config to a .vis file.
        Args:
            filename: path of the file to write
            config: object exposing graph_n, graph_b, tool_tip, node_colors, micro_steps and task
        """
        if config.graph_b:
            # Comparison export: both graphs are stored under their own keys.
            payload = {
                GraphConst.JSON_NPU_KEY: config.graph_n.to_dict(),
                GraphConst.JSON_BENCH_KEY: config.graph_b.to_dict(),
            }
        else:
            # Single-graph export: the graph dict itself is the top-level object.
            payload = config.graph_n.to_dict()
        # Optional sections are written only when they carry a truthy value.
        if config.tool_tip:
            payload[GraphConst.JSON_TIP_KEY] = config.tool_tip
        if config.node_colors:
            payload[GraphConst.COLORS] = config.node_colors
        if config.micro_steps:
            payload[GraphConst.MICRO_STEPS] = config.micro_steps
        if config.task:
            payload[GraphConst.JSON_TASK_KEY] = config.task
        save_json_file(filename, payload)

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """
        If a backward node has no parent, try to derive one from the parent of
        the forward node with the same name.
        Returns: the derived parent id when it exists in construct_dict, otherwise upnode_id unchanged
        """
        # Match names that end with .backward. / .forward. followed by one or more digits
        backward_pattern = r"(\.backward\.)(\d+)$"
        forward_pattern = r"(\.forward\.)(\d+)$"
        # Only backward nodes without a parent need the fallback.
        if not re.search(backward_pattern, subnode_id) or upnode_id:
            return upnode_id
        forward_twin_id = re.sub(backward_pattern, r".forward.\2", subnode_id)
        forward_parent_id = construct_dict.get(forward_twin_id)
        if not forward_parent_id:
            return upnode_id
        candidate = re.sub(forward_pattern, r".backward.\2", forward_parent_id)
        return candidate if candidate in construct_dict else upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict, stack_dict):
        """Create a node for every construct_dict entry and attach it to its parent (or to graph.root)."""
        sources = [data_dict, stack_dict]
        for child_id, declared_parent_id in construct_dict.items():
            parent_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, child_id, declared_parent_id)
            if not parent_id:
                parent = graph.root
            else:
                parent = GraphBuilder._create_or_get_node(
                    graph, sources, NodeOp.get_node_op(parent_id), parent_id)
            GraphBuilder._create_or_get_node(
                graph, sources, NodeOp.get_node_op(child_id), child_id, parent)

    @staticmethod
    def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
        """
        Return the node called name, creating and populating it first when the graph lacks it.
        Args:
            graph: graph that owns the node
            data_stack_list: [dump data dict, stack dict], both looked up by node name
            op: NodeOp used when the node has to be created
            name: node name
            upnode: parent node, may be None
        """
        if name not in graph.node_map:
            graph.add_node(op, name, upnode)
            node = graph.get_node(name)
            dump_entry = data_stack_list[0].get(name, {})
            stack_entry = data_stack_list[1].get(name, [])
            # Split the dump entry into input and output data and store both on the node
            inputs, outputs = get_input_output(dump_entry, node.id)
            node.set_input_output(inputs, outputs)
            node.stack_info = stack_entry
        else:
            node = graph.get_node(name)
        # Register the parent on the node
        node.add_upnode(upnode)
        return node

    @staticmethod
    def _collect_apis_between_modules(graph):
        """
        When the graph is first expanded, the top-level nodes contain many modules
        and apis; a large number of apis stretches the graph and makes it hard to
        read, so runs of apis between modules are gathered into collection nodes.
        Args:
            graph: model structure

        Returns: None
        """
        regrouped = []
        siblings = graph.root.subnodes
        cursor = 0
        while cursor < len(siblings):
            head = siblings[cursor]

            # Modules are copied to the output unchanged
            if head.op != NodeOp.function_api:
                regrouped.append(head)
                cursor += 1
                continue

            # The current node is an api: find where the run of consecutive apis ends
            run_end = cursor + 1
            while run_end < len(siblings) and siblings[run_end].op == NodeOp.function_api:
                run_end += 1
            api_run = siblings[cursor:run_end]
            cursor = run_end

            # A run shorter than 2 is kept as it is
            if len(api_run) < 2:
                regrouped.extend(api_run)
                continue

            # Create a new node and move the apis into its subnodes attribute
            collection_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                           id_accumulation=True)
            collection = graph.get_node(collection_id)
            collection.subnodes = api_run
            # Re-establish the parent/child relationship
            for api_node in api_run:
                api_node.upnode = collection
            collection.upnode = graph.root
            regrouped.append(collection)

        graph.root.subnodes = regrouped
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class GraphExportConfig:
    """Plain container for everything GraphBuilder.to_json reads when exporting."""

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''):
        """
        Args:
            graph_n: graph that is always exported
            graph_b: optional second graph; when set, both graphs are exported
            tool_tip: optional tooltip section
            node_colors: optional node color section
            micro_steps: optional micro-step section
            task: optional task section, empty string by default
        """
        # The graphs to export
        self.graph_n, self.graph_b = graph_n, graph_b
        # Optional sections that are stored as given
        self.tool_tip, self.node_colors = tool_tip, node_colors
        self.micro_steps, self.task = micro_steps, task
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import re
|
|
16
|
+
import math
|
|
17
|
+
from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
|
|
18
|
+
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
|
+
from msprobe.visualization.utils import GraphConst
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
# Rules used to resolve a node name to its NodeOp; each pattern is anchored at
# the start of the name and lists the prefixes that belong to one NodeOp.
# NOTE(review): the dots are unescaped, so each matches any character rather
# than a literal '.' — confirm whether that is intended.
op_patterns = [
    # NodeOp.module
    r'^(Module.|Cell.)',
    # NodeOp.function_api
    r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_compare_mode(dump_path_param):
    """
    Work out the compare mode; the modes are summary, MD5 and real data.
    Args:
        dump_path_param: parameter required by the acc_compare interfaces
    Returns: 0 summary mode, 1 md5 mode, 2 true data mode
    """
    # Prepare the dump paths first, then ask which dump mode they describe.
    set_dump_path(dump_path_param)
    detected_dump_mode = get_dump_mode(dump_path_param)
    # Unknown dump modes yield None because the lookup has no default.
    return GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.get(detected_dump_mode)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
    """
    Run the multi-process comparison that generates the real-data results.
    Args:
        dump_path_param: parameter required by the acc_compare interfaces
        csv_path: path of the generated file
        framework: framework type, pytorch or mindspore
        is_cross_frame: whether to compare across frameworks; only mindspore
            versus pytorch is supported, with pytorch as the benchmark
    Returns: whatever the comparator's do_multi_process returns
    """
    # The comparator modules are imported lazily so only the selected framework is loaded.
    if framework == Const.PT_FRAMEWORK:
        from msprobe.pytorch.compare.pt_compare import PTComparator
        comparator = PTComparator()
    else:
        from msprobe.mindspore.compare.ms_compare import MSComparator
        comparator = MSComparator()
        comparator.cross_frame = is_cross_frame
    return comparator.do_multi_process(dump_path_param, csv_path)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_input_output(node_data, node_id):
    """
    Split the raw dump data of a node into its input and output parts.
    Args:
        node_data: dump data that belongs to a single node
        node_id: node name
    Returns: (input_data, output_data), two dicts of parsed items
    """
    inputs, outputs = {}, {}
    for entry in read_op(node_data, node_id):
        op_name = entry.get('full_op_name', '')
        if not op_name:
            continue
        # An item is an output only when its name mentions output and not input.
        if GraphConst.OUTPUT in op_name and GraphConst.INPUT not in op_name:
            outputs[op_name] = entry
            continue
        dumped_name = entry.get('data_name')
        # Prefer the name of the dumped data file (minus its last segment) as the parameter name
        if isinstance(dumped_name, str) and dumped_name != '-1':
            input_key = dumped_name.rsplit(Const.SEP, 1)[0]
        else:
            input_key = op_name
        inputs[input_key] = entry
    return inputs, outputs
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def compare_data(data_dict_list1, data_dict_list2):
    """
    Check whether two results produced by get_input_output have the same
    structure; returns True when they do.
    Args:
        data_dict_list1: mapping of parameter name to parameter info dict
        data_dict_list2: mapping of parameter name to parameter info dict
    Returns: True when both have the same length and every positional pair
        agrees on 'type' and 'shape', otherwise False
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    # Key fields that decide whether two nodes are considered equal
    key_fields = ['type', 'shape']
    # Entries are paired by iteration order, not by key.
    paired_entries = (
        (data_dict_list1[left_key], data_dict_list2[right_key])
        for left_key, right_key in zip(data_dict_list1, data_dict_list2)
    )
    return not any(
        left.get(field, None) != right.get(field, None)
        for left, right in paired_entries
        for field in key_fields
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def format_node_data(data_dict):
    """
    Format the data of every parameter of a node in one pass.
    Args:
        data_dict: mapping of parameter name to parameter info; modified in place
    Returns: the same data_dict object
    """
    dropped_fields = ('requires_grad', 'full_op_name')
    for entry in data_dict.values():
        # Only dict entries are cleaned; anything else is left as it is.
        if isinstance(entry, dict):
            for field in dropped_fields:
                entry.pop(field, None)
            _format_data(entry)
    return data_dict
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
    """
    Call get_accuracy from acc_compare.py to obtain the accuracy comparison metrics.
    The real-data compare mode cannot obtain the metrics this way; it needs the
    multi-process comparison interface instead.
    Args:
        node_ids: the two node names to compare, indexes 0 and 1
        data_dicts: the dump data dicts that match node_ids by index
        stack_json_data: stack information keyed by node name
        compare_mode: graph compare mode
    Returns: list with the parameter information and the comparison metrics
        (metrics are absent in real-data compare mode)
    """
    # Index 0 is parsed before index 1, matching the order in which the results are passed on.
    merged = [
        _parse_node(node_ids[side], data_dicts[side], stack_json_data, compare_mode)
        for side in (0, 1)
    ]
    accuracy_rows = []
    # get_accuracy fills accuracy_rows in place.
    get_accuracy(accuracy_rows, merged[0], merged[1],
                 GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode))
    return accuracy_rows
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
    """
    Convert a node so that it can be passed to get_accuracy in acc_compare.py.
    Args:
        node_id: node name
        data_dict: dump data keyed by node name
        stack_json_data: stack information keyed by node name
        compare_mode: graph compare mode
    Returns: the result of merge_tensor, with an empty 'op_name' list added when that result is empty
    """
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    parsed_ops = read_op(data_dict.get(node_id, {}), node_id)
    # The stack record is always appended; full_info is None when the node has no stack entry.
    stack_record = stack_json_data[node_id] if node_id in stack_json_data else None
    parsed_ops.append({'full_op_name': node_id, 'full_info': stack_record})
    merged = merge_tensor(parsed_ops, dump_mode)
    if not merged:
        merged['op_name'] = []
    return merged
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _format_decimal_string(s):
    """
    Round every decimal literal in a string that has too many fractional digits.

    A literal is a run of digits, a dot and another run of digits, optionally
    followed by a percent sign. Literals with more than GraphConst.ROUND_TH
    fractional digits are rewritten with exactly GraphConst.ROUND_TH of them;
    a trailing percent sign is kept. All other text is returned untouched.
    Args:
        s: text that may contain decimal numbers
    Returns: the text with each over-long literal rounded at its own position
    """
    precision = GraphConst.ROUND_TH

    def _round_literal(found):
        """Return the replacement text for one matched literal."""
        literal = found.group(0)
        number_str = literal.rstrip('%')
        # Literals that are already short enough are left exactly as written
        if len(number_str.split('.')[1]) <= precision:
            return literal
        rounded = f"{float(number_str):.{precision}f}"
        # Put the percent sign back if the literal had one
        return rounded + '%' if literal.endswith('%') else rounded

    # Substituting per match keeps each rewrite local to that occurrence. A
    # global str.replace of the matched text would also rewrite the same digits
    # where they are the prefix of a longer number, e.g. '0.1234567' inside
    # '0.12345678', and corrupt that number.
    return re.sub(r'\d{1,20}\.\d{1,20}%?', _round_literal, s)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _format_data(data_dict):
    """
    Format the values of a parameter info dict in place: decimals keep 6 places
    and some abnormal values are normalised.
    Args:
        data_dict: parameter info dict; its values are rewritten in place
    Returns: None
    """
    scientific_notation = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
    max_is_null = False
    for field in data_dict:
        raw = data_dict[field]
        if isinstance(raw, str):
            # Drop single quotes and turn None into null so the front end can parse the text
            without_quotes = raw.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            formatted = _format_decimal_string(without_quotes)
        elif raw is None or raw == ' ':
            formatted = GraphConst.NULL
        elif isinstance(raw, float):
            as_text = str(raw)
            if len(as_text) < GraphConst.STR_MAX_LEN and re.match(scientific_notation, as_text):
                # Scientific notation such as 1.123123123123e-11 becomes 1.123123e-11
                formatted = "{:.6e}".format(raw)
            else:
                formatted = round(raw, GraphConst.ROUND_TH)
        else:
            formatted = raw
        # Every value except the one under the error key is stringified to avoid
        # front-end parsing errors; this also serves as the fallback for values
        # not rewritten above (the original note mentions Inf).
        if field != GraphConst.ERROR_KEY:
            formatted = str(formatted)
        # A null max means that the value of this parameter is null
        if field == Const.MAX and formatted == GraphConst.NULL:
            max_is_null = True
        data_dict[field] = formatted
    # When the parameter is null, collapse the whole dict into a single null entry
    if max_is_null:
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|