mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/service.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
import copy
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
import functools
|
|
21
|
+
import os
|
|
22
|
+
|
|
23
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
24
|
+
from msprobe.core.common.file_utils import create_directory
|
|
25
|
+
from msprobe.core.common.runtime import Runtime
|
|
26
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
|
|
27
|
+
from msprobe.core.data_dump.api_registry import ApiRegistry
|
|
28
|
+
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
29
|
+
from msprobe.core.hook_manager import BaseHookManager
|
|
30
|
+
from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseService(ABC):
    """Framework-independent base class of the dump service.

    Implements the shared ``start`` / ``stop`` / ``step`` / ``save`` flow and the
    dump-directory layout. Subclasses implement the abstract members declared
    below and assign ``logger``, ``api_register``, ``api_template`` and
    ``hook_manager`` (all initialised to ``None`` in ``__init__``).

    NOTE(review): ``_run_ut_dispatch`` is called by ``start``/``stop`` when
    ``online_run_ut`` is enabled but is not defined in this class; presumably
    subclasses that support online run_ut provide it - confirm.
    """

    def __init__(self, config):
        """Copy the config, build the data collector and register the API hooks.

        Args:
            config: task configuration object; it is deep-copied, so later
                changes made here (e.g. ``kernel_config_path``) do not touch
                the caller's instance.
        """
        self.config = copy.deepcopy(config)
        self.config.level = getattr(config, 'level_ori', config.level)  # compatibility with the MindSpore config
        self.model = None
        self.data_collector = build_data_collector(self.config)
        self.attl_manager = None
        self.current_iter = 0
        self.loop = 0  # incremented by step()
        self.init_step = 0  # offset added to ``loop`` to obtain ``current_iter``
        self.cur_token_id = 0  # forward-call counter used by the token_range hook
        self.first_start = True  # cleared once start() has registered the hooks
        self.primitive_switch = False
        self.current_rank = None
        self.dump_iter_dir = None  # set by create_dirs()
        self.should_stop_service = False  # latched to True by _need_stop_service()
        self.ori_customer_func = {}  # original callables saved by register_custom_api()
        self.debug_variable_counter = None  # per-name save() counter, recreated on the first save() of a step
        self.currrent_step_first_debug_save = True  # reset to True by step()
        self.logger = None  # injected by the subclass
        self.api_register = None  # injected by the subclass
        self.api_template = None  # injected by the subclass
        self.hook_manager = None  # injected by the subclass
        # _register_api_hook() uses self.api_register / self.logger, so the subclass
        # presumably assigns them in _init_specific_components() - confirm.
        self._init_specific_components()
        self._register_api_hook()

    @property
    def _is_debug_level(self):
        """True when the configured level is ``Const.LEVEL_DEBUG``."""
        return self.config.level == Const.LEVEL_DEBUG

    @property
    def _is_l2_level(self):
        """True when the configured level is ``Const.LEVEL_L2``."""
        return self.config.level == Const.LEVEL_L2

    @property
    def _is_mix_level(self):
        """True when the configured level is ``Const.LEVEL_MIX``."""
        return self.config.level == Const.LEVEL_MIX

    @property
    def _is_need_module_hook(self):
        """True when the level (mix or L0) requires module-level hooks."""
        return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]

    @property
    def _is_need_api_hook(self):
        """True when the level (mix, L1 or L2) requires API hooks."""
        return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]

    @property
    def _is_no_dump_step(self):
        """Truthy when ``config.step`` is set and the current iteration is not listed in it.

        When ``config.step`` is empty/unset the falsy ``config.step`` value itself
        is returned, so callers should only rely on truthiness.
        """
        return (self.config.step and self.current_iter not in self.config.step)

    @property
    def _is_no_dump_rank(self):
        """Truthy when ``config.rank`` is set and the current rank is not listed in it.

        When ``config.rank`` is empty/unset the falsy ``config.rank`` value itself
        is returned, so callers should only rely on truthiness.
        """
        return (self.config.rank and self.current_rank not in self.config.rank)

    @property
    def _need_tensor_data(self):
        """Whether tensor data needs to be collected."""
        return bool(
            self.config.task in self.data_collector.tasks_need_tensor_data or
            (self.config.task == Const.STATISTICS and self.config.tensor_list)
        )

    @property
    def _is_online_run_ut(self):
        """Value of ``config.online_run_ut``; ``False`` when the attribute is absent."""
        return getattr(self.config, "online_run_ut", False)

    @property
    @abstractmethod
    def _get_framework_type(self):
        """Return the framework type."""
        pass

    @staticmethod
    @abstractmethod
    def _get_current_rank():
        """Return the current rank id."""
        pass

    @staticmethod
    def _change_jit_switch(status):
        """Toggle the JitDump switch; overridden by the MindSpore subclass."""
        pass

    def start(self, model=None, token_range=None):
        """Generic start template.

        Updates the iteration counters, then (unless the level is debug or the
        service must stop / skip this step) registers hooks on the first call,
        turns the dump switch on and prepares the dump directories.

        Args:
            model: model stored on ``self.model`` and, when ``token_range`` is
                given, passed to ``_register_infer_count_hook``.
            token_range: optional ``[start, end]`` token range; when given, the
                dump switch is not turned on here but by the inference hook.
        """
        self._process_iteration()
        if self._is_debug_level:
            return
        if self._need_stop_service():
            return
        self.model = model
        self.cur_token_id = 0
        if self.first_start:
            try:
                self.current_rank = self._get_current_rank()
            except DistributedNotInitializedError:
                self.current_rank = None
            Runtime.current_rank = self.current_rank
            if self._is_no_dump_rank:
                # first_start stays True here, so the rank is re-evaluated on the next start()
                return
            self._register_hook()
            if self._is_need_module_hook:
                self._register_module_hook()
            self.first_start = False

        if token_range:
            self._register_infer_count_hook(self.model, token_range)
        self.logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")
        if token_range is None:
            Runtime.is_running = True
            self.primitive_switch = True
            self._change_jit_switch(True)
        self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
        if self._is_online_run_ut:
            self._run_ut_dispatch(True)
        else:
            self.create_dirs()
            # NOTE(review): nested under ``else`` because dump_iter_dir is only set by
            # create_dirs(); confirm the nesting against the upstream file.
            self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Generic stop template.

        Does nothing at debug level, after the service has been stopped, or when
        the current step/rank is not selected for dumping. Otherwise turns the
        dump switch off and, except at L2 level, flushes the collected data.
        """
        if self._is_debug_level or self.should_stop_service:
            return
        if self._is_no_dump_step or self._is_no_dump_rank:
            return
        self.logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
                         "Please set debugger.start() to turn on the dump switch again. ")
        Runtime.is_running = False
        self.primitive_switch = False
        self._change_jit_switch(False)
        if self._is_l2_level:
            # L2 returns before anything is flushed through the data collector
            return
        if self._is_online_run_ut:
            self._run_ut_dispatch(False)
        self._process_async_dump()
        self.data_collector.write_json()

    def step(self):
        """Generic step handling.

        Flushes the collected data, re-arms the per-step debug-save flag,
        advances the loop counter and resets the per-step state. Does nothing
        once the service has been stopped.
        """
        if self.should_stop_service:
            return
        self._process_async_dump()
        self.data_collector.write_json()
        self.currrent_step_first_debug_save = True
        self.loop += 1
        self._reset_status()

    def save(self, variable, name, save_backward):
        '''
        Save a variable (and optionally its gradient) at debug level.

        Only active when the configured level is debug and the current step is
        selected for dumping. On the first call of a step the rank is resolved,
        the dump directories are created and the per-name counter is reset.
        Each call stores the data under ``{name}.{count}`` (gradient:
        ``{name}_grad.{count}``), where ``count`` is the number of earlier
        saves of ``name`` in the current step.

        Args:
            variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
            name: str
            save_backward: boolean
        Return:
            void
        '''
        if not self._is_debug_level:
            return
        self.current_iter = self.loop + self.init_step
        if self._is_no_dump_step:
            return

        if self.currrent_step_first_debug_save:
            try:
                self.current_rank = self._get_current_rank()
            except DistributedNotInitializedError:
                self.current_rank = None

            self.create_dirs()
            self.debug_variable_counter = defaultdict(int)
            self.currrent_step_first_debug_save = False

        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1

        name_with_count = f"{name}.{count}"
        grad_name_with_count = f"{name}_grad.{count}"

        # forward save
        self.data_collector.debug_data_collect_forward(variable, name_with_count)

        # backward save
        if save_backward:
            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)

    def register_custom_api(self, module, api_name, api_prefix):
        """Hook a user-specified API and remember the original callable.

        The original attribute is stored in ``ori_customer_func`` under
        ``str(module) + Const.SEP + api_name`` so ``restore_custom_api`` can put
        it back.
        """
        self.ori_customer_func[str(module) + Const.SEP + api_name] = getattr(module, api_name)
        ApiRegistry.register_custom_api(module, api_name, api_prefix,
                                        functools.partial(self.build_hook, Const.API), self.api_template)

    def restore_custom_api(self, module, api):
        """Restore an API previously replaced by ``register_custom_api``.

        Does nothing when no (truthy) original was recorded for ``module``/``api``.
        """
        ori_func = self.ori_customer_func.get(str(module) + Const.SEP + api)
        if ori_func:
            setattr(module, api, ori_func)

    def build_hook(self, hook_type, name):
        """Delegate hook construction to ``self.hook_manager``."""
        return self.hook_manager.build_hook(hook_type, name)

    def create_dirs(self):
        """Unified directory-creation logic.

        Creates ``config.dump_path``, derives ``dump_iter_dir`` for the current
        iteration (with an extra ``Const.PYNATIVE_MODE`` level when
        ``Runtime.run_mode`` is ``Const.PYNATIVE_GRAPH_MODE``) and then creates
        either the L2 or the default directory layout.
        """
        create_directory(self.config.dump_path)
        if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
            self.dump_iter_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, f"step{self.current_iter}")
        else:
            self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")

        # an unknown rank is rendered as an empty string (directory name becomes "rank")
        cur_rank = self.current_rank if self.current_rank is not None else ''
        if self._is_l2_level:
            self._create_l2_dirs(cur_rank)
        else:
            self._create_default_dirs(cur_rank)

    @abstractmethod
    def _init_specific_components(self):
        """Initialise the framework-specific components."""
        pass

    @abstractmethod
    def _register_hook(self):
        """Register the hook functions."""
        pass

    @abstractmethod
    def _register_module_hook(self):
        """Register the module-level hook functions."""

    def _need_stop_service(self):
        """Decide whether ``start`` should return without enabling the dump.

        Returns True when the service was already stopped, when it has to be
        stopped now (current iteration is past the largest configured step, or
        the data processor reports ``is_terminated``), or when the current step
        is simply not selected. Stopping the service switches everything off,
        latches ``should_stop_service`` and prints the end-of-tool message.
        """
        if self.should_stop_service:
            return True
        # ``and`` binds tighter than ``or``: (step configured and exceeded) or (collector terminated)
        end_service = self.config.step and self.current_iter > max(self.config.step) or \
                      self.data_collector and self.data_collector.data_processor.is_terminated
        if end_service:
            if self._is_online_run_ut and self.attl_manager:
                self.attl_manager.attl_stop()
            self.primitive_switch = False
            self._change_jit_switch(False)
            Runtime.is_running = False
            self.should_stop_service = True
            print_tools_ends_info()
            return True
        if self._is_no_dump_step:
            return True
        return False

    def _register_api_hook(self):
        """Initialise and register all API hooks when the level requires them."""
        if self._is_need_api_hook:
            self.api_register.initialize_hook(functools.partial(self.build_hook, Const.API))
            self.api_register.register_all_api()
            self.logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")

    def _register_infer_count_hook(self, root_model, token_range):
        """
        Determine the index of the current token by counting how many times root_model has been executed.
        param root_model: the inference model to collect data from
        param token_range: [start, end], token loop range to collect during inference, both ends inclusive
        return: None
        """
        def infer_hook(model, args):
            # Runs before every forward call of root_model: switch the dump on at
            # token_range[0] and off at the first token after token_range[1].
            if self.cur_token_id == token_range[0]:
                Runtime.is_running = True
                self.primitive_switch = True
                self._change_jit_switch(True)
                self.logger.info(f"Current token id: {self.cur_token_id}, start dump infer token.")
            elif token_range[0] < self.cur_token_id <= token_range[1]:
                self.logger.debug(f"Current token id: {self.cur_token_id}.")
            elif self.cur_token_id == token_range[1] + 1:
                Runtime.is_running = False
                self.primitive_switch = False
                self._change_jit_switch(False)
                self.logger.info(
                    f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.")
            self.cur_token_id += 1
        if isinstance(root_model, list):
            root_model = root_model[0]
            self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
        if self._is_online_run_ut:
            # no token-counting hook is installed in online run_ut mode
            return
        root_model.register_forward_pre_hook(infer_hook)

    def _create_l2_dirs(self, cur_rank):
        """Create the step directory and the kernel config JSON used at L2 level.

        The path returned by ``create_kernel_config_json`` is stored on
        ``config.kernel_config_path``.
        """
        create_directory(self.dump_iter_dir)
        kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
        self.config.kernel_config_path = kernel_config_path

    def _create_default_dirs(self, cur_rank):
        """Create ``rank{cur_rank}`` (and ``dump_tensor_data`` when tensors are needed) and publish the paths."""
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        create_directory(dump_dir)

        # stays None when no tensor data has to be written
        dump_data_dir = None
        if self._need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)

        self._configure_dump_paths(dump_dir, dump_data_dir)

    def _configure_dump_paths(self, dump_dir, dump_data_dir):
        """Build the output file paths under ``dump_dir`` and hand them to the data collector.

        Args:
            dump_dir: rank-level directory that receives the JSON/CSV files.
            dump_data_dir: tensor data directory, or ``None``.
        """
        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        # _get_framework_type is a property, so its value (not a callable) is passed
        self.data_collector.initialize_json_file(self._get_framework_type)

    def _process_iteration(self):
        """Update the iteration counters on this object, the data collector and ``Runtime``."""
        self.current_iter = self.loop + self.init_step
        self.data_collector.update_iter(self.current_iter)
        Runtime.current_iter = self.current_iter

    def _process_async_dump(self):
        """Flush asynchronously collected data for statistics/tensor tasks when ``async_dump`` is enabled."""
        if self.config.async_dump and self.config.task in [Const.STATISTICS, Const.TENSOR]:
            self.data_collector.data_processor.dump_async_data()

    def _reset_status(self):
        """Generic state reset performed at the end of every step."""
        self.data_collector.reset_status()
        BaseHookManager.params_grad_info.clear()
        if self._is_l2_level:
            self.data_collector.data_processor.reset_status()
|
|
File without changes
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory, save_excel
|
|
25
|
+
from msprobe.core.common.log import logger
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class CompareResult:
    """Metrics describing how two arrays differ (declared return type of ``SingleComparator.compare_arrays``).

    NOTE(review): the field notes below follow the like-named values computed in
    ``compare_arrays``; the place where the result is constructed is not visible
    here - confirm.
    """
    # largest absolute element-wise difference
    max_abs_error: float
    # largest relative element-wise difference
    max_relative_error: float
    # share of equal elements, as a percentage (0-100)
    same_percentage: float
    # flat index of the first differing element; compare_arrays computes None
    # for this value when no element differs, despite the ``int`` annotation
    first_mismatch_index: int
    # presumably the percentage of elements whose error is within 1/1000 - confirm
    percentage_within_thousandth: float
    # presumably the percentage of elements whose error is within 1/100 - confirm
    percentage_within_hundredth: float
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SingleComparator:
|
|
39
|
+
result_header = [
|
|
40
|
+
'step',
|
|
41
|
+
'rank',
|
|
42
|
+
'micro_step',
|
|
43
|
+
'id',
|
|
44
|
+
'shape1',
|
|
45
|
+
'shape2',
|
|
46
|
+
'相同元素百分比(%)',
|
|
47
|
+
'首个不匹配元素索引',
|
|
48
|
+
'最大绝对误差',
|
|
49
|
+
'最大相对误差',
|
|
50
|
+
'误差在千分之一内元素占比(%)',
|
|
51
|
+
'误差在百分之一内元素占比(%)'
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
@classmethod
|
|
55
|
+
def compare(cls, dir1, dir2, output_path="./msprobe_compare_output", num_processes=8):
|
|
56
|
+
data_dir1 = os.path.join(dir1, "data")
|
|
57
|
+
data_dir2 = os.path.join(dir2, "data")
|
|
58
|
+
check_file_or_directory_path(data_dir1, isdir=True)
|
|
59
|
+
check_file_or_directory_path(data_dir2, isdir=True)
|
|
60
|
+
# 确保输出目录存在,如果不存在则创建
|
|
61
|
+
if not os.path.exists(output_path):
|
|
62
|
+
create_directory(output_path)
|
|
63
|
+
cls.compare_data(data_dir1, data_dir2, output_path, num_processes)
|
|
64
|
+
|
|
65
|
+
@classmethod
|
|
66
|
+
def compare_arrays(cls, array1, array2) -> CompareResult:
|
|
67
|
+
"""
|
|
68
|
+
比较两个NumPy数组,计算最大绝对误差、最大相对误差和相同元素的百分比
|
|
69
|
+
"""
|
|
70
|
+
# 计算每个维度上的最小尺寸
|
|
71
|
+
min_shape = [min(s1, s2) for s1, s2 in zip(array1.shape, array2.shape)]
|
|
72
|
+
# 截取数组到相同的形状
|
|
73
|
+
sliced_array1 = array1[tuple(slice(0, s) for s in min_shape)]
|
|
74
|
+
sliced_array2 = array2[tuple(slice(0, s) for s in min_shape)]
|
|
75
|
+
|
|
76
|
+
abs_error = np.abs(sliced_array1 - sliced_array2)
|
|
77
|
+
max_abs_error = np.max(abs_error)
|
|
78
|
+
|
|
79
|
+
# 计算相对误差,处理分母为零的情况
|
|
80
|
+
with np.errstate(divide='ignore', invalid='ignore'):
|
|
81
|
+
relative_error = np.abs(sliced_array1 - sliced_array2) / \
|
|
82
|
+
np.maximum(np.abs(sliced_array1), np.abs(sliced_array2))
|
|
83
|
+
relative_error = np.nan_to_num(relative_error)
|
|
84
|
+
max_relative_error = np.max(relative_error)
|
|
85
|
+
|
|
86
|
+
same_elements = np.sum(sliced_array1 == sliced_array2)
|
|
87
|
+
total_elements = sliced_array1.size
|
|
88
|
+
same_percentage = (same_elements / total_elements) * 100
|
|
89
|
+
|
|
90
|
+
# 展平数组
|
|
91
|
+
flat_array1 = sliced_array1.flatten()
|
|
92
|
+
flat_array2 = sliced_array2.flatten()
|
|
93
|
+
|
|
94
|
+
# 计算从第几个元素开始对不上
|
|
95
|
+
mismatch_indices = np.nonzero(flat_array1 != flat_array2)[0]
|
|
96
|
+
first_mismatch_index = mismatch_indices[0] if mismatch_indices.size > 0 else None
|
|
97
|
+
|
|
98
|
+
# 计算误差在千分之一内的元素占比
|
|
99
|
+
threshold = 0.001 * np.maximum(np.abs(sliced_array1), np.abs(sliced_array2))
|
|
100
|
+
error_within_thousandth = np.sum(abs_error <= threshold)
|
|
101
|
+
percentage_within_thousandth = (error_within_thousandth / total_elements) * 100
|
|
102
|
+
|
|
103
|
+
# 计算误差在百分之一内的元素占比
|
|
104
|
+
threshold = 0.01 * np.maximum(np.abs(sliced_array1), np.abs(sliced_array2))
|
|
105
|
+
error_within_hundredth = np.sum(abs_error <= threshold)
|
|
106
|
+
percentage_within_hundredth = (error_within_hundredth / total_elements) * 100
|
|
107
|
+
|
|
108
|
+
return CompareResult(
|
|
109
|
+
max_abs_error,
|
|
110
|
+
max_relative_error,
|
|
111
|
+
same_percentage,
|
|
112
|
+
first_mismatch_index,
|
|
113
|
+
percentage_within_thousandth,
|
|
114
|
+
percentage_within_hundredth
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
@classmethod
|
|
118
|
+
def get_steps(cls, tag_path):
|
|
119
|
+
for step_folder in os.listdir(tag_path):
|
|
120
|
+
if step_folder.startswith('step'):
|
|
121
|
+
try:
|
|
122
|
+
step = int(step_folder[4:])
|
|
123
|
+
except Exception as e:
|
|
124
|
+
raise RuntimeError(f"parse step number error") from e
|
|
125
|
+
yield step, os.path.join(tag_path, step_folder)
|
|
126
|
+
|
|
127
|
+
@classmethod
|
|
128
|
+
def get_ranks(cls, step_path):
|
|
129
|
+
for rank_folder in os.listdir(step_path):
|
|
130
|
+
if rank_folder.startswith('rank'):
|
|
131
|
+
try:
|
|
132
|
+
rank = int(rank_folder[4:])
|
|
133
|
+
except Exception as e:
|
|
134
|
+
raise RuntimeError(f"parse rank number error") from e
|
|
135
|
+
yield rank, os.path.join(step_path, rank_folder)
|
|
136
|
+
|
|
137
|
+
@classmethod
|
|
138
|
+
def get_micro_steps(cls, rank_path):
|
|
139
|
+
for micro_step_folder in os.listdir(rank_path):
|
|
140
|
+
if micro_step_folder.startswith('micro_step'):
|
|
141
|
+
try:
|
|
142
|
+
micro_step = int(micro_step_folder[10:])
|
|
143
|
+
except Exception as e:
|
|
144
|
+
raise RuntimeError(f"parse nicro_step number error") from e
|
|
145
|
+
yield micro_step, os.path.join(rank_path, micro_step_folder)
|
|
146
|
+
else:
|
|
147
|
+
yield 0, rank_path
|
|
148
|
+
|
|
149
|
+
@classmethod
|
|
150
|
+
def get_arrays(cls, micro_step_path):
|
|
151
|
+
for file in os.listdir(micro_step_path):
|
|
152
|
+
if file.endswith('.npy'):
|
|
153
|
+
try:
|
|
154
|
+
parts = file.rsplit('.', 2)
|
|
155
|
+
if len(parts) > 1 and parts[-2].isdigit():
|
|
156
|
+
array_id = int(parts[-2])
|
|
157
|
+
else:
|
|
158
|
+
array_id = 0
|
|
159
|
+
except ValueError:
|
|
160
|
+
array_id = 0
|
|
161
|
+
yield array_id, os.path.join(micro_step_path, file)
|
|
162
|
+
|
|
163
|
+
@classmethod
|
|
164
|
+
def get_array_paths(cls, dir_path):
|
|
165
|
+
"""
|
|
166
|
+
获取目录中所有符合结构的NumPy数组文件路径
|
|
167
|
+
"""
|
|
168
|
+
array_paths = {}
|
|
169
|
+
if not os.path.exists(dir_path):
|
|
170
|
+
return array_paths
|
|
171
|
+
for tag in os.listdir(dir_path):
|
|
172
|
+
tag_path = os.path.join(dir_path, tag)
|
|
173
|
+
if not os.path.isdir(tag_path):
|
|
174
|
+
continue
|
|
175
|
+
for step, step_path in cls.get_steps(tag_path):
|
|
176
|
+
for rank, rank_path in cls.get_ranks(step_path):
|
|
177
|
+
for micro_step, micro_step_path in cls.get_micro_steps(rank_path):
|
|
178
|
+
for array_id, array_path in cls.get_arrays(micro_step_path):
|
|
179
|
+
array_paths.setdefault(tag, []).append((step, rank, micro_step, array_id, array_path))
|
|
180
|
+
return array_paths
|
|
181
|
+
|
|
182
|
+
@classmethod
|
|
183
|
+
def compare_single_tag(cls, tag, array_paths1, array_paths2, output_dir):
|
|
184
|
+
try:
|
|
185
|
+
data = []
|
|
186
|
+
paths1 = array_paths1.get(tag, [])
|
|
187
|
+
paths2 = array_paths2.get(tag, [])
|
|
188
|
+
path_dict1 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths1}
|
|
189
|
+
path_dict2 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths2}
|
|
190
|
+
common_keys = set(path_dict1.keys()) & set(path_dict2.keys())
|
|
191
|
+
for key in common_keys:
|
|
192
|
+
try:
|
|
193
|
+
array1 = np.load(path_dict1[key])
|
|
194
|
+
array2 = np.load(path_dict2[key])
|
|
195
|
+
result = cls.compare_arrays(array1, array2)
|
|
196
|
+
step, rank, micro_step, array_id = key
|
|
197
|
+
data.append([
|
|
198
|
+
step, rank, micro_step, array_id,
|
|
199
|
+
list(array1.shape), list(array2.shape),
|
|
200
|
+
result.same_percentage,
|
|
201
|
+
result.first_mismatch_index,
|
|
202
|
+
result.max_abs_error,
|
|
203
|
+
result.max_relative_error,
|
|
204
|
+
result.percentage_within_thousandth,
|
|
205
|
+
result.percentage_within_hundredth
|
|
206
|
+
])
|
|
207
|
+
except Exception as e:
|
|
208
|
+
logger.error(f"Error comparing {path_dict1[key]} and {path_dict2[key]}: {e}")
|
|
209
|
+
|
|
210
|
+
df = pd.DataFrame(data, columns=SingleComparator.result_header)
|
|
211
|
+
df = df.sort_values(by=['step', 'rank', 'micro_step', 'id'])
|
|
212
|
+
# 构建输出文件的完整路径
|
|
213
|
+
output_file_path = os.path.join(output_dir, f'{tag}.xlsx')
|
|
214
|
+
save_excel(output_file_path, df)
|
|
215
|
+
except Exception as e:
|
|
216
|
+
logger.error(f"Error processing tag {tag}: {e}")
|
|
217
|
+
|
|
218
|
+
@classmethod
|
|
219
|
+
def compare_data(cls, dir1, dir2, output_dir, num_processes=8):
|
|
220
|
+
"""
|
|
221
|
+
比较两个目录中的NumPy数组文件,并将结果保存到指定目录的Excel文件中
|
|
222
|
+
"""
|
|
223
|
+
|
|
224
|
+
array_paths1 = cls.get_array_paths(dir1)
|
|
225
|
+
array_paths2 = cls.get_array_paths(dir2)
|
|
226
|
+
|
|
227
|
+
all_tags = set(array_paths1.keys()) | set(array_paths2.keys())
|
|
228
|
+
|
|
229
|
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
|
230
|
+
args = [(tag, array_paths1, array_paths2, output_dir) for tag in all_tags]
|
|
231
|
+
try:
|
|
232
|
+
results = pool.starmap_async(cls.compare_single_tag, args)
|
|
233
|
+
with tqdm(total=len(all_tags), desc="Processing data") as pbar:
|
|
234
|
+
while not results.ready():
|
|
235
|
+
pbar.n = len(all_tags) - results._number_left
|
|
236
|
+
pbar.refresh()
|
|
237
|
+
results.wait()
|
|
238
|
+
results.get()
|
|
239
|
+
except Exception as e:
|
|
240
|
+
logger.error(f"Multiprocessing error: {e}")
|
|
241
|
+
finally:
|
|
242
|
+
pool.close()
|
|
243
|
+
pool.join()
|