mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import multiprocessing
|
|
19
|
+
from functools import partial
|
|
20
|
+
|
|
21
|
+
import pandas as pd
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
|
|
25
|
+
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
26
|
+
from msprobe.core.common.utils import CompareException, add_time_with_xlsx
|
|
27
|
+
from msprobe.core.compare.utils import table_value_is_valid
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def check_compare_result_name(file_name):
    """
    Check whether a compare result file name is one that should be merged.

    Args:
        file_name: base name of a compare result file.

    Returns:
        bool: True for a multi-rank result named
        ``compare_result_rank{N}-rank{N}_{14 digits}.xlsx`` (same N on both sides);
        False for a single-rank result named ``compare_result_rank-rank_{14 digits}.xlsx``,
        which needs no merging (a warning is logged).

    Raises:
        CompareException: if the name matches neither pattern.
    """
    # The dot before the suffix is escaped so it only matches a literal '.', and
    # re.fullmatch is used so a trailing newline is not accepted the way '$' would.
    single_rank_pattern = r"compare_result_rank-rank_\d{14}\.xlsx"
    multi_ranks_pattern = r"compare_result_rank(\d+)-rank\1_\d{14}\.xlsx"
    if re.fullmatch(multi_ranks_pattern, file_name):
        return True
    if re.fullmatch(single_rank_pattern, file_name):
        logger.warning("Single rank compare result do not need to be merged.")
        return False
    logger.error(f"Wrong compare result name: {file_name}, please check!")
    raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def reorder_path(compare_result_path_list):
    """
    Return the compare result paths sorted by the rank number in their file names.

    The rank number is the integer following ``compare_result_rank`` in the base name.
    """
    rank_regex = re.compile(r"compare_result_rank(\d+)-rank")

    def _rank_of(path):
        # Numeric (not lexical) rank, so rank10 sorts after rank2
        return int(rank_regex.search(os.path.basename(path)).group(1))

    return sorted(compare_result_path_list, key=_rank_of)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_result_path(input_dir):
    """
    Collect the multi-rank compare result files in ``input_dir``, ordered by rank.

    Every ``.xlsx`` file is name-checked (an unexpected name raises); single-rank results
    are skipped, and each accepted file goes through a readable-file check.

    Raises:
        CompareException: if fewer than two mergeable compare results are found.
    """
    candidate_paths = [
        os.path.join(input_dir, entry_name)
        for entry_name in os.listdir(input_dir)
        if entry_name.endswith(FileCheckConst.XLSX_SUFFIX)
    ]

    checked_paths = []
    for candidate in candidate_paths:
        if not check_compare_result_name(os.path.basename(candidate)):
            continue
        checker = FileChecker(candidate, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
        checked_paths.append(checker.common_check())

    # Reorder the multi-rank compare results by rank number
    ordered_paths = reorder_path(checked_paths)

    if len(ordered_paths) < 2:
        # A single compare result has nothing to merge, so stop here
        logger.warning("Number of compare result is no more than 1, no need to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    return ordered_paths
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_dump_mode(result_df, rank_num):
    """
    Infer the dump mode of a compare result from its column header.

    Args:
        result_df: compare result table of one rank.
        rank_num: rank number, used only in the warning message.

    Returns:
        Const.ALL, Const.SUMMARY or Const.MD5 when the header matches one of the known
        layouts (with or without the stack column); "" otherwise, after logging a warning.
    """
    header = result_df.columns.tolist()

    all_mode_headers = [CompareConst.COMPARE_RESULT_HEADER + [CompareConst.DATA_NAME],
                        CompareConst.COMPARE_RESULT_HEADER_STACK + [CompareConst.DATA_NAME]]
    if header in all_mode_headers:
        return Const.ALL

    summary_mode_headers = [CompareConst.SUMMARY_COMPARE_RESULT_HEADER,
                            CompareConst.SUMMARY_COMPARE_RESULT_HEADER_STACK]
    if header in summary_mode_headers:
        return Const.SUMMARY

    md5_mode_headers = [CompareConst.MD5_COMPARE_RESULT_HEADER, CompareConst.MD5_COMPARE_RESULT_HEADER_STACK]
    if header in md5_mode_headers:
        return Const.MD5

    logger.warning(f"A valid dump task can not be identified from rank{rank_num} compare result, please check! "
                   f"The compare result will not be shown in merged result.")
    return ""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def check_index_dump_mode_consistent(dump_mode, rank_num):
    """
    Check that the compare indexes to merge are valid for the given dump mode.

    The requested indexes are read from the module-level ``share_compare_index_list``.
    If that list is empty, it is filled with every compare index of ``dump_mode``.
    Ranks handled afterwards therefore reuse the same index set.

    Args:
        dump_mode: dump mode identified from this rank's compare result header.
        rank_num: rank number, used only in log messages.

    Returns:
        list: the compare indexes to merge for this rank. Returns [] (after logging a
        warning) when the rank must be excluded: the md5 dump mode, an unsupported dump
        mode, or requested indexes that this dump mode does not provide.
    """
    if dump_mode == Const.MD5:
        logger.warning(f"Rank{rank_num} compare result is 'md5' dump task and does not support merging result, please "
                       f"check! The compare result will not be shown in merged result.")
        return []

    dump_mode_compare_index_map = {
        Const.ALL: CompareConst.ALL_COMPARE_INDEX,
        Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX
    }
    valid_compare_index = dump_mode_compare_index_map.get(dump_mode)
    if valid_compare_index is None:
        # Guard against extend(None) / issubset(None) for a dump mode without an index table
        logger.warning(f"Rank{rank_num} compare result has unsupported dump mode '{dump_mode}'. "
                       f"The compare result will not be shown in merged result.")
        return []

    share_list = list(share_compare_index_list)

    # No compare_index configured: merge every compare index of this dump mode
    if not share_list:
        share_compare_index_list.extend(valid_compare_index)
        return list(share_compare_index_list)
    if set(share_list).issubset(valid_compare_index):
        return share_list

    # Invalid = requested indexes that this dump mode does not produce
    invalid_compare_index = set(share_list) - set(valid_compare_index)
    logger.warning(f"Compare indexes in rank{rank_num} compare result are not consistent with "
                   f"those in other compare results, please check!")
    logger.warning(f"The compare result will not be shown in merged result.")
    logger.warning(f"The invalid compare indexes: {invalid_compare_index}")
    return []
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def extract_api_full_name(api_list, result_df, rank_num):
    """
    Find the full API names in a compare result that belong to the APIs in ``api_list``.

    A row matches an API when its NPU name contains ``api + Const.SEP`` as a literal substring.

    Args:
        api_list: API names from the merge config.
        result_df: compare result table of one rank.
        rank_num: rank number, used only in log messages.

    Returns:
        list: matching NPU names, grouped in ``api_list`` order and then in row order.
        APIs with no match are logged and skipped.
    """
    api_full_name_list = []
    npu_names = result_df[CompareConst.NPU_NAME]
    for api in api_list:
        api_pat = api + Const.SEP
        # regex=False matches literally, so any regex metacharacter in an API name
        # (not only '.') is treated as plain text.
        matched = npu_names.str.contains(api_pat, regex=False, na=False)
        single_api_full_name_list = npu_names[matched].tolist()
        if len(single_api_full_name_list) == 0:
            logger.warning(f"{api} not found in rank{rank_num} compare result.")
            continue
        api_full_name_list.extend(single_api_full_name_list)
    return api_full_name_list
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def search_api_index_result(api_list, compare_index_list, result_df, rank_num, compare_index_dict):
    """
    Parse one rank's compare result into the intermediate dict ``compare_index_dict``.

    The dict is filled in place (and returned) in the form::

        {
            compare_index1: {
                api_full_name1: {rank_num: value},
                api_full_name2: {rank_num: value},
                ...
            },
            compare_index2: {...},
            ...
        }

    Args:
        api_list: API names from the merge config.
        compare_index_list: compare index column names to extract.
        result_df: compare result table of this rank.
        rank_num: rank number, used as the innermost key.
        compare_index_dict: dict updated in place and returned.

    Raises:
        RuntimeError: if an API name or a cell value is rejected by table_value_check.
    """
    api_full_name_list = extract_api_full_name(api_list, result_df, rank_num)

    # Resolve each API's row once; the row does not depend on the compare index.
    # If an NPU name occurs on several rows, the first one is used.
    api_row_map = {}
    for api_full_name in api_full_name_list:
        if api_full_name in api_row_map:
            continue
        table_value_check(api_full_name)
        api_row_map[api_full_name] = result_df.index[result_df[CompareConst.NPU_NAME] == api_full_name].tolist()[0]

    for compare_index in compare_index_list:
        api_index_dict = {}
        for api_full_name, row_num in api_row_map.items():
            index_value = result_df.loc[row_num, compare_index]
            table_value_check(index_value)
            api_index_dict[api_full_name] = {rank_num: index_value}
        compare_index_dict[compare_index] = api_index_dict
    return compare_index_dict
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def table_value_check(value):
    """
    Reject a value that ``table_value_is_valid`` does not accept.

    Presumably this guards the merged xlsx against unsafe cell content (e.g. spreadsheet
    formulas) — TODO confirm against table_value_is_valid.

    Raises:
        RuntimeError: if the value is not valid.
    """
    if table_value_is_valid(value):
        return
    raise RuntimeError(
        f"Malicious value [{value}] is not allowed to be written into the merged xlsx.")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def result_process(compare_result_path_list, api_list):
    """
    Parse a group of compare result files into intermediate per-rank dicts.

    A rank is skipped (with a warning) when:
    - its table is empty;
    - its dump mode cannot be identified;
    - its compare indexes are not valid for its dump mode.

    Skipping one rank does not discard the ranks already parsed in the same group.

    Args:
        compare_result_path_list: compare result file paths (file names carry the rank number).
        api_list: API names from the merge config.

    Returns:
        tuple(list, list, list): (compare_index_dict_list, rank_num_list, compare_index_list).
        The first two are aligned per parsed rank. The third is the compare index list used by
        the last parsed rank, or [] if no rank was parsed.
    """
    rank_pattern = r"compare_result_rank(\d+)-rank"
    compare_index_dict_list = []
    rank_num_list = []
    compare_index_list = []

    for compare_result_path in compare_result_path_list:
        result_df = read_xlsx(compare_result_path)
        rank_num = int(re.search(rank_pattern, os.path.basename(compare_result_path)).group(1))
        logger.info(f"Parsing rank{rank_num} compare result...")
        if result_df.empty:
            logger.warning(f"Rank{rank_num} compare result is empty and will not shown in merged result.")
            continue

        dump_mode = get_dump_mode(result_df, rank_num)
        if dump_mode == "":
            # get_dump_mode already warned that this rank is excluded; skip only this rank
            continue
        # compare_index is fixed once chosen, and it determines the dump mode. Checking that
        # compare_index and dump_mode agree therefore ensures all ranks share one dump mode.
        rank_compare_index_list = check_index_dump_mode_consistent(dump_mode, rank_num)
        if not rank_compare_index_list:
            continue

        compare_index_dict = search_api_index_result(api_list, rank_compare_index_list,
                                                     result_df, rank_num, {})
        compare_index_dict_list.append(compare_index_dict)
        rank_num_list.append(rank_num)
        compare_index_list = rank_compare_index_list

    return compare_index_dict_list, rank_num_list, compare_index_list
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def handle_multi_process(func, func_args, lock):
    """
    Run ``func`` over chunks of compare result files in a process pool and gather its outputs.

    Args:
        func: worker invoked as ``func(chunk, api_list)``. It is expected to return a 3-tuple
            ``(compare_index_dict_list, rank_num_list, compare_index_list)``, like ``result_process``.
        func_args: tuple ``(compare_result_path_list, api_list)``.
        lock: lock held while the progress bar is updated from pool callbacks.

    Returns:
        tuple(list, list, list): for every chunk, in submission order, the three values
        returned by ``func``.

    Raises:
        CompareException: if no chunk produced anything to merge.
        Any exception from a worker is re-raised after the pool has been stopped and joined.
    """
    compare_result_path_list, api_list = func_args

    result_num = len(compare_result_path_list)
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    if result_num <= process_num:
        process_num = result_num
        chunks = [[compare_result_path] for compare_result_path in compare_result_path_list]
    else:
        chunk_size = result_num // process_num
        chunks = [compare_result_path_list[i:i + chunk_size] for i in range(0, result_num, chunk_size)]

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("Pool terminate failed")

    progress_bar = tqdm(total=result_num, desc="Compare Result Parsing Process", unit="num", ncols=100)

    def update_progress(size, progress_lock, extra_param=None):
        # extra_param receives the worker's return value from apply_async and is unused
        with progress_lock:
            progress_bar.update(size)

    all_compare_index_dict_list = []
    all_rank_num_list = []
    all_compare_index_list_list = []
    try:
        results = []
        for chunk in chunks:
            # apply_async returns an ApplyResult immediately, so `results` keeps chunk order
            result = pool.apply_async(func,
                                      args=(chunk, api_list),
                                      error_callback=err_call,
                                      callback=partial(update_progress, len(chunk), lock)
                                      )
            results.append(result)

        for result in results:
            compare_index_dict, rank_num_list, compare_index_list = result.get()
            all_compare_index_dict_list.append(compare_index_dict)
            all_rank_num_list.append(rank_num_list)
            all_compare_index_list_list.append(compare_index_list)
    except BaseException:
        # Stop any still-running workers before propagating instead of leaving the pool alive
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
        progress_bar.close()

    if not any(all_compare_index_dict_list):
        logger.warning("Nothing to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    return all_compare_index_dict_list, all_rank_num_list, all_compare_index_list_list
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def generate_result_df(api_index_dict, header):
    """
    Turn one compare index's ``{api_full_name: {rank: value}}`` mapping into a DataFrame.

    Each row is the API full name followed by the values of its inner dict, in insertion
    order. All columns use the ``object`` dtype.

    Args:
        api_index_dict: e.g. ``{"api_full_name1": {rank1: value}, "api_full_name2": {rank1: value}}``.
        header: column names; the row width must match it.

    Returns:
        pandas.DataFrame with ``header`` as columns.
    """
    rows = [
        [api_full_name, *rank_value_dict.values()]
        for api_full_name, rank_value_dict in api_index_dict.items()
    ]
    return pd.DataFrame(rows, columns=header, dtype="object")
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_compare_index_list_list, output_dir):
    """
    Build the merged workbook from the intermediate per-rank dicts and save it in ``output_dir``.

    One sheet is written per compare index. Each sheet holds the NPU name column plus one
    ``rank{N}`` column per parsed rank.

    Raises:
        CompareException: if no chunk reported a non-empty compare index list.
    """
    output_path = os.path.join(output_dir, add_time_with_xlsx("multi_ranks_compare_merge"))

    # The first non-empty compare index list provides the sheet names
    compare_index_list = next((item for item in all_compare_index_list_list if item), None)
    if not compare_index_list:
        logger.error("No compare index recognized, please check!")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    all_result_df_list = []
    for chunk_dict_list, chunk_rank_list in zip(all_compare_index_dict_list, all_rank_num_list):
        for compare_index_dict, rank_num in zip(chunk_dict_list, chunk_rank_list):
            rank_header = [CompareConst.NPU_NAME, "rank" + str(rank_num)]
            all_result_df_list.append(
                [generate_result_df(api_index_dict, rank_header) for api_index_dict in compare_index_dict.values()]
            )

    merge_df_list = df_merge(all_result_df_list)
    # merge_df_list[i] corresponds to compare_index_list[i]
    final_result_df_list = [(df, compare_index_list[i]) for i, df in enumerate(merge_df_list)]
    save_excel(output_path, final_result_df_list)
    logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def df_merge(all_result_df_list):
    """
    Merge the per-rank result DataFrames position by position.

    The i-th DataFrame of every inner list is outer-joined on ``CompareConst.NPU_NAME`` with the
    i-th DataFrame of the first inner list.

    Args:
        all_result_df_list: list of lists of DataFrames, each containing a ``CompareConst.NPU_NAME`` column.

    Returns:
        A new list of merged DataFrames, each with the ``CompareConst.NPU_NAME`` column first.
        The input lists are left unmodified.

    Raises:
        CompareException: if ``all_result_df_list`` is empty.
    """
    if not all_result_df_list:
        # This path aborts the merge, so log it at error level like the other abort paths.
        logger.error("Nothing to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    if len(all_result_df_list) == 1:
        logger.info("Only one compare result gets merge data.")

    # Copy so that the caller's first inner list is not overwritten with the merged frames.
    merge_df_base = list(all_result_df_list[0])
    for sublist in all_result_df_list[1:]:
        for i, sub_df in enumerate(sublist):
            merge_df_base[i] = pd.merge(merge_df_base[i], sub_df, on=CompareConst.NPU_NAME, how='outer')

    # Put the NPU name column first and keep the remaining columns in their merged order.
    return [
        df.reindex(columns=[CompareConst.NPU_NAME] + [col for col in df.columns if col != CompareConst.NPU_NAME])
        for df in merge_df_base
    ]
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# Module-level list of compare indexes. initialize_compare_index rebinds it to a
# multiprocessing Manager list proxy (presumably so worker processes can read it).
share_compare_index_list = []


def initialize_compare_index(config):
    """
    Rebind the module-level ``share_compare_index_list`` to a Manager-backed list.

    Args:
        config: mapping read with ``.get``; its ``"compare_index"`` entry seeds the list
            (an empty list is used when the key is absent).
    """
    global share_compare_index_list
    sync_manager = multiprocessing.Manager()
    initial_indexes = config.get("compare_index", [])
    share_compare_index_list = sync_manager.list(initial_indexes)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def merge_result(input_dir, output_dir, config_path):
    """
    Merge the multi-rank compare results found under ``input_dir`` into one xlsx file.

    Args:
        input_dir: directory of compare result files; checked with FileChecker as a readable directory.
        output_dir: directory for the merged workbook; passed to ``create_directory`` first.
        config_path: yaml file whose ``api`` entry (required) lists the APIs to merge and whose
            optional ``compare_index`` entry seeds the shared compare index list.

    Raises:
        CompareException: if the yaml content is empty or provides no ``api`` entries.
    """
    checked_input_dir = FileChecker(input_dir, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
    create_directory(output_dir)

    # Full paths of all compare result files in the input directory. The original note says that
    # fewer than two files triggers a prompt and exit inside get_result_path.
    result_paths = get_result_path(checked_input_dir)

    merge_config = load_yaml(config_path)
    if not merge_config:
        logger.error('config.yaml is empty, please check.')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    target_apis = merge_config.get('api')
    if not target_apis:
        logger.error('The APIs required to merge data were not found')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    # Populate the shared global share_compare_index_list before handle_multi_process runs.
    initialize_compare_index(merge_config)

    process_args = (result_paths, target_apis)
    process_lock = multiprocessing.Manager().RLock()
    index_dicts, rank_nums, index_lists = handle_multi_process(result_process, process_args, process_lock)

    generate_merge_result(index_dicts, rank_nums, index_lists, output_dir)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.compare.merge_result.merge_result import merge_result
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _merge_result_parser(parser):
|
|
20
|
+
parser.add_argument("-i", "--input_dir", dest="input_dir", type=str,
|
|
21
|
+
help="<Required> The compare result path, a dir.", required=True)
|
|
22
|
+
parser.add_argument("-o", "--output_dir", dest="output_dir", type=str,
|
|
23
|
+
help="<Required> The result merge output path, a dir.", required=True)
|
|
24
|
+
parser.add_argument("-config", "--config-path", dest="config_path", type=str,
|
|
25
|
+
help="<Required> Yaml path containing distribute APIs and compare indexes for merging data "
|
|
26
|
+
"from compare results.",
|
|
27
|
+
required=True)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def merge_result_cli(args):
    """CLI handler: forward the parsed ``input_dir``, ``output_dir`` and ``config_path`` to ``merge_result``."""
    input_dir, output_dir, config_path = args.input_dir, args.output_dir, args.config_path
    merge_result(input_dir, output_dir, config_path)
|
|
@@ -23,7 +23,7 @@ from msprobe.core.common.const import CompareConst
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def _handle_multi_process(func, input_parma, result_df, lock):
|
|
26
|
-
process_num = int((multiprocessing.cpu_count() + 1)
|
|
26
|
+
process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
|
|
27
27
|
op_name_mapping_dict = read_dump_data(result_df)
|
|
28
28
|
|
|
29
29
|
df_chunk_size = len(result_df) // process_num
|
|
@@ -63,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
|
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
66
|
-
process_num = int((multiprocessing.cpu_count() + 1) // 4)
|
|
66
|
+
process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
|
|
67
67
|
df_chunk_size = len(result_df) // process_num
|
|
68
68
|
if df_chunk_size > 0:
|
|
69
69
|
df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
|