mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
17
|
+
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
|
|
18
|
+
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
19
|
+
from msprobe.visualization.graph.node_colors import NodeColors
|
|
20
|
+
from msprobe.visualization.compare.mode_adapter import ModeAdapter
|
|
21
|
+
from msprobe.core.common.const import Const
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GraphComparator:
    """Match nodes of an NPU graph against a benchmark graph and annotate the NPU graph.

    The comparison walks the NPU graph from its root, looks up a counterpart for
    every node in the benchmark graph, and stores the resulting precision index
    (plus any mode-specific extras) in ``node.data`` of the NPU node.
    """

    def __init__(self, graphs, dump_path_param, output_path, framework=Const.PT_FRAMEWORK, mapping_dict=None):
        """Store both graphs and load the dump files referenced by ``dump_path_param``.

        Args:
            graphs: indexable pair; item 0 is the NPU graph, item 1 the benchmark graph.
            dump_path_param: mapping read with ``.get`` for the keys ``'npu_json_path'``,
                ``'bench_json_path'`` and ``'stack_json_path'``; also passed to
                ``get_compare_mode`` and ``run_real_data``.
            output_path: kept on the instance; this class does not read it again.
            framework: forwarded unchanged to ``run_real_data``.
            mapping_dict: when truthy, nodes are paired with ``Graph.mapping_match``
                instead of ``Graph.match``.
        """
        self.graph_n = graphs[0]
        self.graph_b = graphs[1]
        self._parse_param(dump_path_param, output_path)
        self.framework = framework
        self.mapping_dict = mapping_dict

    def compare(self):
        """Run the comparison. Call separately after construction; results are written into ``graph_n``."""
        self._compare_nodes(self.graph_n.root)
        self._postcompare()

    def add_compare_result_to_node(self, node, compare_result_list):
        """Attach comparison results to the input/output data of ``node``.

        Args:
            node: the NPU node to annotate.
            compare_result_list: rows holding parameter info and, except in
                real-data mode, the comparison metrics. Only non-empty list/tuple
                rows are used; ``row[0]`` is treated as the parameter name.
        """
        # Real-data comparison: only park the node here. Its metrics are attached in
        # _postcompare, once the multi-process comparison has produced them.
        if self.ma.prepare_real_data(node):
            return
        compare_in_dict = {}
        compare_out_dict = {}
        # Separate output rows from everything else, keyed by parameter name.
        for item in compare_result_list:
            if not isinstance(item, (list, tuple)) or not item:
                continue
            if '.output.' in item[0]:
                compare_out_dict[item[0]] = item
            else:
                compare_in_dict[item[0]] = item
        precision_index, other_dict = self.ma.parse_result(node, [compare_in_dict, compare_out_dict])
        node.data[GraphConst.JSON_INDEX_KEY] = precision_index
        node.data.update(other_dict)
        if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
            node.get_suggestions()

    def _parse_param(self, dump_path_param, output_path):
        """Record the paths, build the ``ModeAdapter`` and load the three JSON dump files."""
        self.dump_path_param = dump_path_param
        self.output_path = output_path
        compare_mode = get_compare_mode(self.dump_path_param)
        self.ma = ModeAdapter(compare_mode)
        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))

    def _postcompare(self):
        """Aggregate api-collection indices and, in real-data mode, attach the deferred metrics."""
        self._handle_api_collection_index()
        if self.ma.compare_mode != GraphConst.REAL_DATA_COMPARE:
            return
        df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
        df = run_real_data(self.dump_path_param, df, self.framework, bool(self.mapping_dict))
        # Key each row by its first column. ``.iloc[0]`` is used because positional
        # ``row[0]`` on a label-indexed Series is deprecated in pandas 2.x.
        compare_data_dict = {row.iloc[0]: row.tolist() for _, row in df.iterrows()}
        for node in self.ma.compare_nodes:
            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index
            if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
                node.get_suggestions()

    def _handle_api_collection_index(self):
        """Give every api-collection node directly under the root an aggregated index.

        In md5 mode the collection takes the smallest index of its APIs; in the
        other modes it takes the largest. Per the original author's note, 0 is the
        worst value in md5 mode and 1 is the worst value in the other modes.
        """
        is_md5 = self.ma.compare_mode == GraphConst.MD5_COMPARE
        for node in self.graph_n.root.subnodes:
            if node.op != NodeOp.api_collection:
                continue
            if is_md5:
                precision_index = GraphConst.MAX_INDEX_KEY
                for api in node.subnodes:
                    precision_index = min(precision_index,
                                          api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY))
            else:
                precision_index = GraphConst.MIN_INDEX_KEY
                for api in node.subnodes:
                    precision_index = max(precision_index,
                                          api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _compare_nodes(self, node_n):
        """Recursively visit the NPU tree and compare each node that has a benchmark match.

        The traversal is pre-order: a node is matched before its children, so the
        ancestors of a node are already linked when the node itself is compared
        (the original author notes this helps later fuzzy matching).
        """
        if self.mapping_dict:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
            if node_b:
                ancestors_n.append(node_n.id)
                ancestors_b.append(node_b.id)
                node_n.matched_node_link = ancestors_b
                node_b.matched_node_link = ancestors_n
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
        if node_b:
            # In real-data mode this only yields basic information without metrics;
            # the metrics come from the multi-process comparison run in _postcompare.
            compare_result_list = compare_node([node_n.id, node_b.id],
                                               [self.data_n_dict, self.data_b_dict],
                                               self.stack_json_data, self.ma.compare_mode)
            if compare_result_list:
                self.ma.add_csv_data(compare_result_list)
                self.add_compare_result_to_node(node_n, compare_result_list)
        for subnode in node_n.subnodes:
            self._compare_nodes(subnode)
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import math
|
|
18
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
19
|
+
from msprobe.visualization.utils import ToolTip, GraphConst, str2float
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ModeAdapter:
|
|
23
|
+
def __init__(self, compare_mode):
|
|
24
|
+
self.compare_mode = compare_mode
|
|
25
|
+
self.csv_data = []
|
|
26
|
+
self.compare_nodes = []
|
|
27
|
+
|
|
28
|
+
@staticmethod
|
|
29
|
+
def _add_md5_compare_data(node_data, compare_data_dict):
|
|
30
|
+
precision_index = GraphConst.MAX_INDEX_KEY
|
|
31
|
+
for key, value in node_data.items():
|
|
32
|
+
if not isinstance(value, dict):
|
|
33
|
+
continue
|
|
34
|
+
compare_data = compare_data_dict.get(key)
|
|
35
|
+
if compare_data:
|
|
36
|
+
headers = CompareConst.MD5_COMPARE_RESULT_HEADER
|
|
37
|
+
id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST]
|
|
38
|
+
ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list)
|
|
39
|
+
# md5比对是否通过
|
|
40
|
+
if value.get(CompareConst.RESULT) != CompareConst.PASS:
|
|
41
|
+
precision_index = GraphConst.MIN_INDEX_KEY
|
|
42
|
+
node_data[key] = value
|
|
43
|
+
return precision_index
|
|
44
|
+
|
|
45
|
+
@staticmethod
|
|
46
|
+
def _add_real_compare_data(node_data, compare_data_dict):
|
|
47
|
+
min_thousandth = float(1)
|
|
48
|
+
numbers = []
|
|
49
|
+
for key, value in node_data.items():
|
|
50
|
+
if not isinstance(value, dict):
|
|
51
|
+
continue
|
|
52
|
+
compare_data = compare_data_dict.get(key)
|
|
53
|
+
if compare_data:
|
|
54
|
+
headers = CompareConst.COMPARE_RESULT_HEADER
|
|
55
|
+
id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST]
|
|
56
|
+
ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list)
|
|
57
|
+
# 跳过scalar data,因为无法计算双千指标,会得到Nan
|
|
58
|
+
if not value.get(Const.SHAPE):
|
|
59
|
+
continue
|
|
60
|
+
# 获取一个节点所有的输入或输出最小的双千指标
|
|
61
|
+
thousandth = value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO)
|
|
62
|
+
# 可能是None,可能是非数字内容str
|
|
63
|
+
try:
|
|
64
|
+
thousandth = float(thousandth)
|
|
65
|
+
except (ValueError, TypeError):
|
|
66
|
+
thousandth = None
|
|
67
|
+
if thousandth is not None:
|
|
68
|
+
numbers.append(thousandth)
|
|
69
|
+
node_data[key] = value
|
|
70
|
+
# 双千指标都是None的异常情况
|
|
71
|
+
if not numbers:
|
|
72
|
+
min_thousandth = None
|
|
73
|
+
else:
|
|
74
|
+
min_thousandth = min(numbers + [min_thousandth])
|
|
75
|
+
return min_thousandth
|
|
76
|
+
|
|
77
|
+
@staticmethod
|
|
78
|
+
def _add_summary_compare_data(node_data, compare_data_dict):
|
|
79
|
+
max_relative_err = GraphConst.MIN_INDEX_KEY
|
|
80
|
+
# data_info: {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [2, 536320], 'Max': 9.66036224, ...}
|
|
81
|
+
for key, data_info in node_data.items():
|
|
82
|
+
if not isinstance(data_info, dict):
|
|
83
|
+
continue
|
|
84
|
+
compare_data = compare_data_dict.get(key)
|
|
85
|
+
if compare_data:
|
|
86
|
+
dtype = data_info.get(Const.DTYPE)
|
|
87
|
+
# 对应比对结果csv的列
|
|
88
|
+
key_list = GraphConst.SUMMARY_INDEX_LIST
|
|
89
|
+
headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
|
|
90
|
+
id_list = [headers.index(x) for x in key_list]
|
|
91
|
+
ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
|
|
92
|
+
for index, item in enumerate(key_list[4:]):
|
|
93
|
+
value = data_info.get(GraphConst.VALUE_INDEX_LIST[index])
|
|
94
|
+
value_diff = data_info.get(key_list[index])
|
|
95
|
+
relative_err = str2float(data_info.get(item))
|
|
96
|
+
if isinstance(value, float) and isinstance(value_diff, float) \
|
|
97
|
+
and dtype in GraphConst.SMALL_VALUES.keys():
|
|
98
|
+
small_value = GraphConst.SMALL_VALUES.get(dtype)
|
|
99
|
+
# 小值域
|
|
100
|
+
if abs(value) <= small_value:
|
|
101
|
+
data_info[item] = ToolTip.SMALL_VALUE_TIP.format(data_info.get(item),
|
|
102
|
+
GraphConst.VALUE_INDEX_LIST[index],
|
|
103
|
+
small_value)
|
|
104
|
+
relative_err = GraphConst.MIN_INDEX_KEY \
|
|
105
|
+
if abs(value_diff) <= GraphConst.SMALL_VALUES_ABS_ERROR.get(dtype) \
|
|
106
|
+
else GraphConst.MAX_INDEX_KEY
|
|
107
|
+
max_relative_err = max(max_relative_err, relative_err)
|
|
108
|
+
node_data[key] = data_info
|
|
109
|
+
max_relative_err = 1 if max_relative_err > 1 else max_relative_err
|
|
110
|
+
return max_relative_err
|
|
111
|
+
|
|
112
|
+
@staticmethod
|
|
113
|
+
def _match_data(data_dict, compare_data, key_list, id_list):
|
|
114
|
+
"""
|
|
115
|
+
绑定精度指标到node的input_data和output_data
|
|
116
|
+
"""
|
|
117
|
+
if len(key_list) != len(id_list):
|
|
118
|
+
return
|
|
119
|
+
for id_val, key in zip(id_list, key_list):
|
|
120
|
+
data_dict[key] = compare_data[id_val]
|
|
121
|
+
|
|
122
|
+
@staticmethod
|
|
123
|
+
def _check_list_len(data_list, len_num):
|
|
124
|
+
if len(data_list) < len_num:
|
|
125
|
+
raise ValueError(f"compare_data_dict_list must contain at least {len_num} items.")
|
|
126
|
+
|
|
127
|
+
def parse_result(self, node, compare_data_dict_list):
|
|
128
|
+
"""
|
|
129
|
+
根据结果返回数据,分别是precision_index,和附加数据
|
|
130
|
+
"""
|
|
131
|
+
|
|
132
|
+
other_dict = {}
|
|
133
|
+
if self.compare_mode == GraphConst.MD5_COMPARE:
|
|
134
|
+
ModeAdapter._check_list_len(compare_data_dict_list, 2)
|
|
135
|
+
precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict_list[0])
|
|
136
|
+
precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict_list[1])
|
|
137
|
+
# 所有输入输出md5对比通过,这个节点才算通过
|
|
138
|
+
precision_index = min(precision_index_in, precision_index_out)
|
|
139
|
+
other_result = CompareConst.PASS if precision_index == GraphConst.MAX_INDEX_KEY else CompareConst.DIFF
|
|
140
|
+
other_dict[CompareConst.RESULT] = other_result
|
|
141
|
+
elif self.compare_mode == GraphConst.SUMMARY_COMPARE:
|
|
142
|
+
ModeAdapter._check_list_len(compare_data_dict_list, 2)
|
|
143
|
+
ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict_list[0])
|
|
144
|
+
precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict_list[1])
|
|
145
|
+
precision_index = precision_index_out
|
|
146
|
+
else:
|
|
147
|
+
ModeAdapter._check_list_len(compare_data_dict_list, 1)
|
|
148
|
+
min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
|
|
149
|
+
min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
|
|
150
|
+
if min_thousandth_in is not None and min_thousandth_out is not None:
|
|
151
|
+
change_percentage = min_thousandth_in - min_thousandth_out
|
|
152
|
+
else:
|
|
153
|
+
change_percentage = GraphConst.MIN_INDEX_KEY
|
|
154
|
+
change_percentage = GraphConst.MIN_INDEX_KEY if change_percentage < GraphConst.MIN_INDEX_KEY \
|
|
155
|
+
else change_percentage
|
|
156
|
+
precision_index = GraphConst.MAX_INDEX_KEY \
|
|
157
|
+
if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
|
|
158
|
+
return precision_index, other_dict
|
|
159
|
+
|
|
160
|
+
def prepare_real_data(self, node):
|
|
161
|
+
"""
|
|
162
|
+
为真实数据比较模式准备节点信息
|
|
163
|
+
"""
|
|
164
|
+
if self.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
165
|
+
self.compare_nodes.append(node)
|
|
166
|
+
return True
|
|
167
|
+
return False
|
|
168
|
+
|
|
169
|
+
def add_csv_data(self, compare_result_list):
|
|
170
|
+
if self.compare_mode != GraphConst.REAL_DATA_COMPARE:
|
|
171
|
+
return
|
|
172
|
+
self.csv_data.extend(compare_result_list)
|
|
173
|
+
|
|
174
|
+
def add_error_key(self, node_data):
|
|
175
|
+
"""
|
|
176
|
+
根据不同的模式进行提供不同错误信息
|
|
177
|
+
"""
|
|
178
|
+
for key, value in node_data.items():
|
|
179
|
+
if not isinstance(value, dict):
|
|
180
|
+
continue
|
|
181
|
+
if self.compare_mode == GraphConst.SUMMARY_COMPARE:
|
|
182
|
+
message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
|
|
183
|
+
CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
|
|
184
|
+
elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
185
|
+
message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
|
|
186
|
+
else:
|
|
187
|
+
# 输出件优化
|
|
188
|
+
message = []
|
|
189
|
+
value[GraphConst.ERROR_KEY] = message
|
|
190
|
+
node_data[key] = value
|
|
191
|
+
|
|
192
|
+
def get_tool_tip(self):
    """
    Build the tooltip texts that explain the displayed fields to the front end.

    Returns:
        A JSON string mapping each indicator name to its tooltip; the set of
        indicators depends on the current compare mode.
    """
    mode = self.compare_mode
    if mode == GraphConst.SUMMARY_COMPARE:
        return json.dumps({
            CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
            CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
            CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
            CompareConst.NORM_DIFF: ToolTip.NORM_DIFF,
        })
    if mode == GraphConst.MD5_COMPARE:
        return json.dumps({Const.MD5: ToolTip.MD5})
    # Any other mode falls back to the real-data indicators.
    return json.dumps({
        CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
        CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
        CompareConst.COSINE: ToolTip.COSINE,
        CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
        CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR,
    })
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
from msprobe.core.overflow_check.level import OverflowLevel
|
|
16
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
17
|
+
from msprobe.visualization.utils import Suggestions, GraphConst
|
|
18
|
+
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseNode:
    """
    A single node of the visualization graph.

    A node stores its operator type, its id, its input/output data and its
    place in the hierarchy: ``upnode`` is the parent, ``subnodes`` the children.
    """

    # Equality depends on mutable input/output data, so instances are
    # unhashable. Python already does this implicitly when ``__eq__`` is
    # defined without ``__hash__``; it is spelled out here on purpose.
    __hash__ = None

    def __init__(self, node_op, node_id, up_node=None):
        """
        Args:
            node_op: operator type of the node (compared against ``NodeOp``
                members, and its ``.value`` is exported by ``to_dict``).
            node_id: identifier of the node.
            up_node: optional parent node; this node is registered as one of
                its subnodes.
        """
        self.op = node_op
        self.id = node_id
        self.data = {}
        self.output_data = {}
        self.input_data = {}
        self.upnode = None
        self.add_upnode(up_node)
        self.subnodes = []
        self.matched_node_link = []
        self.suggestions = {}
        self.stack_info = []
        self.micro_step_id = None
        self.overflow_level = None

    def __str__(self):
        return f'id:\t{self.id}'

    def __eq__(self, other):
        """
        Tell whether two nodes can be matched, i.e. are structurally consistent.

        Two nodes are equal when ``compare_data`` accepts both their input
        data and their output data. Objects that are not ``BaseNode``
        instances yield ``NotImplemented`` so that ``node == None`` evaluates
        to False instead of raising AttributeError.
        """
        if not isinstance(other, BaseNode):
            return NotImplemented
        if not compare_data(self.input_data, other.input_data):
            return False
        if not compare_data(self.output_data, other.output_data):
            return False
        return True

    def get_suggestions(self):
        """
        Fill ``self.suggestions`` with hints for a node whose precision looks
        suspicious. Only module and function-api nodes get suggestions.
        """
        if self.op == NodeOp.module:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module
            self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
        elif self.op == NodeOp.function_api:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
            self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL

    def set_input_output(self, input_data, output_data):
        """Replace the input and output data of the node."""
        self.input_data = input_data
        self.output_data = output_data

    def set_overflow_level(self, level):
        """Store ``level`` if it is a truthy ``OverflowLevel``; otherwise do nothing."""
        if not level or not isinstance(level, OverflowLevel):
            return
        self.overflow_level = level

    def add_upnode(self, node):
        """
        Bind ``node`` as the parent of this node and register this node as its child.

        The call is ignored when ``node`` is falsy, has the same id as this
        node, or when a parent has already been bound.
        """
        if not node or node.id == self.id or self.upnode:
            return
        self.upnode = node
        node.subnodes.append(self)

    def add_link(self, node, ancestors):
        """
        Record the match information once two nodes have been matched.

        Args:
            node: the node matched with ``self``.
            ancestors: ancestor information of the opposite node.

        Note:
            Both nodes keep a reference to the very same ``ancestors`` list.
        """
        self.matched_node_link = ancestors
        node.matched_node_link = ancestors

    def to_dict(self):
        """
        Export the node as a plain dict.

        ``micro_step_id`` is only included when set. When an overflow level
        has been stored, it is written into ``self.data`` and ``data`` is
        included in the result.
        """
        result = {
            'id': self.id,
            'node_type': self.op.value,
            'output_data': format_node_data(self.output_data),
            'input_data': format_node_data(self.input_data),
            'upnode': self.upnode.id if self.upnode else 'None',
            'subnodes': [node.id for node in self.subnodes],
            'matched_node_link': self.matched_node_link,
            'suggestions': self.suggestions,
            'stack_info': self.stack_info
        }
        if self.micro_step_id is not None:
            result['micro_step_id'] = self.micro_step_id
        # Save the overflow level, if one was detected for this node.
        if self.overflow_level and isinstance(self.overflow_level, OverflowLevel):
            if self.data is None:
                self.data = {}
            self.data['overflow_level'] = self.overflow_level.value
            result['data'] = self.data
        return result

    def get_ancestors(self):
        """Return the ids of all ancestors, ordered from the topmost one down to the parent."""
        ancestors = []
        current_node = self.upnode
        while current_node:
            ancestors.append(current_node.id)
            current_node = current_node.upnode
        ancestors.reverse()
        return ancestors
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
16
|
+
from msprobe.visualization.graph.base_node import BaseNode
|
|
17
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
18
|
+
from msprobe.visualization.utils import GraphConst
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# NOTE(review): presumably an upper bound on recursion depth for graph
# traversal; it is not referenced in the visible part of this module - confirm usage.
MAX_RECUR_LEVEL = 100
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Graph:
    """
    A graph of ``BaseNode`` objects indexed by node id.

    ``node_map`` maps node id -> node, ``node_id_map`` keeps the per-id
    counters used when ids are accumulated (see ``add_node``), and ``root`` is
    the module node created for ``model_name``.
    """

    def __init__(self, model_name, data_path='', dump_data=None):
        self.node_map = {}
        self.node_id_map = {}
        self.add_node(NodeOp.module, model_name)
        self.root = self.get_node(model_name)
        self.data_path = data_path
        self.dump_data = dump_data

    def __str__(self):
        return "\n".join(str(node) for node in self.node_map.values())

    @staticmethod
    def match(graph_n, node_n, graph_b):
        """
        Find the node of ``graph_b`` that corresponds to ``node_n``.

        Precondition: the parent of ``node_n`` has already been matched.
        Matching is currently exact (same id, equal nodes, identical ancestor
        chain); fuzzy matching may be added here later.

        Returns:
            ``(matched_node, ancestors)``, or ``(None, [])`` when nothing matches.
        """
        if not node_n or node_n.id not in graph_b.node_map:
            return None, []
        node_b = graph_b.node_map.get(node_n.id)
        if node_n != node_b:
            return None, []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        if ancestors_n != ancestors_b:
            return None, []
        return node_b, ancestors_n

    @staticmethod
    def mapping_match(node_n, graph_b, mapping_dict):
        """
        Match a node according to the mapping configuration.

        The id of ``node_n`` is translated through ``mapping_dict`` (falling
        back to the id itself) and looked up in ``graph_b``.

        Returns:
            ``(matched_node, ancestors_of_node_n, ancestors_of_matched_node)``,
            or ``(None, [], [])`` when ``node_n`` is falsy or nothing matches.
        """
        # Same guard as ``match``: a missing node simply does not match.
        if not node_n:
            return None, [], []
        node_b = graph_b.node_map.get(mapping_dict.get(node_n.id, node_n.id))
        if not node_b:
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def dfs(node, result):
        """Depth-first (pre-order) export of ``node`` and its subnodes into ``result``, keyed by node id."""
        info = node.to_dict()
        result[node.id] = info
        for subnode in node.subnodes:
            Graph.dfs(subnode, result)

    @staticmethod
    def split_nodes_by_micro_step(nodes):
        """
        Split the nodes of one step into micro steps, based on Module names.

        A micro step must be one complete forward + backward pass.
        Example::
            =============== micro step0
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step1
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step2
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward

        A non-Module node is assigned to the micro step of the preceding Module node.

        Returns:
            dict mapping micro step index -> list of nodes.
        """
        result = {}
        micro_step = 0
        result[micro_step] = []
        backward_flag = False

        for node in nodes:
            if node.op == NodeOp.module:
                if f'{Const.SEP}{Const.FORWARD}{Const.SEP}' in node.id:
                    # A forward module seen after a backward one opens the next micro step.
                    if backward_flag:
                        micro_step += 1
                        result[micro_step] = []
                        backward_flag = False
                else:
                    backward_flag = True
            result[micro_step].append(node)
        return result

    def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
        """
        Add a node to the graph.

        Args:
            node_op: type of the node to add.
            node_id: id of the node to add.
            up_node: parent of the new node.
            id_accumulation: when True, a running index is appended to
                ``node_id`` (``<node_id>.<n>``) so that repeated ids stay distinct.

        Returns:
            The id under which the node is stored. Without accumulation, an
            already existing id is returned unchanged and no node is created.
        """
        if node_id in self.node_map:
            if not id_accumulation:
                return node_id
            # The plain id is already taken: make sure a counter exists, but do
            # not reset one that is already running, otherwise every further
            # accumulated add would yield ``<node_id>.1`` again and overwrite
            # the node stored under that id.
            self.node_id_map.setdefault(node_id, 0)
        if id_accumulation:
            if node_id in self.node_id_map:
                self.node_id_map[node_id] += 1
            else:
                self.node_id_map[node_id] = 0
            node_id = f'{node_id}.{self.node_id_map[node_id]}'
        node = BaseNode(node_op, node_id, up_node)
        self.node_map[node_id] = node
        return node_id

    def get_node(self, node_id):
        """Return the node stored under ``node_id``, or None when it does not exist."""
        return self.node_map.get(node_id, None)

    def to_dict(self):
        """Export the graph (root id, data path and every node) as a plain dict."""
        result = {}
        result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
        result[GraphConst.JSON_DATA_KEY] = self.data_path
        result[GraphConst.JSON_NODE_KEY] = {}
        for node_id, node in self.node_map.items():
            result[GraphConst.JSON_NODE_KEY][node_id] = node.to_dict()
        return result

    def paging_by_micro_step(self, graph_other=None):
        """
        Tag the first-level nodes of the graph with a micro step id so that the
        front end can display them page by page, which helps when handling
        large graphs. In the compare scenario the micro step ids of the
        corresponding nodes in ``graph_other`` are updated as well.

        Args:
            graph_other: optional second graph.

        Returns:
            The number of batches (micro steps).
        """
        batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
        for batch_number, nodes in batches_n.items():
            for node in nodes:
                node.micro_step_id = batch_number
                # Propagate the micro step id to the matched node in graph_other.
                if graph_other and node.matched_node_link:
                    node_other = graph_other.get_node(node.matched_node_link[-1])
                    if node_other:
                        node_other.micro_step_id = batch_number
        # Walk the direct children of graph_other's root so that unmatched
        # nodes receive a micro step id too.
        if graph_other:
            for node in graph_other.root.subnodes:
                if node.micro_step_id is None:
                    try:
                        micro_step_id = int(node.id.split(Const.SEP)[-1])
                    except ValueError:
                        # The last id component is not a number: fall back to step 0.
                        micro_step_id = 0
                    node.micro_step_id = micro_step_id
        return len(batches_n)

    def overflow_check(self):
        """Run the anomaly detector on ``dump_data`` and store the overflow level on every affected node."""
        detector = AnomalyDetector(self.dump_data)
        detector.analyze().filter()

        for node_id, _node in self.node_map.items():
            if detector.has_overflow(node_id):
                lv = detector.get_overflow_level(node_id)
                _node.set_overflow_level(lv)
|