mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Dict, List, Tuple, Optional, Any
|
|
20
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
21
|
+
from functools import partial
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import pandas as pd
|
|
25
|
+
import numpy as np
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
|
|
28
|
+
from msprobe.core.common.log import logger
|
|
29
|
+
from msprobe.core.common.utils import CompareException
|
|
30
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
31
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
|
|
32
|
+
check_path_before_create, load_npy
|
|
33
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
34
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply
|
|
35
|
+
from msprobe.core.compare.multiprocessing_compute import check_accuracy
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
    """
    Compare two dump directories, mirroring the input directory layout under ``output_dir``.

    Every directory below ``npu_path`` that directly contains ``.npy`` files and has a
    counterpart (same relative path) below ``bench_path`` is compared in its own worker
    process; its results are written to ``<output_dir>/<relative dir>/result.csv``.

    Args:
        input_params: dict with the required keys ``npu_path`` and ``bench_path`` and an
            optional ``map_dict`` mapping NPU file-name keys to bench file-name keys.
        output_dir: root directory of the mirrored result files.

    Returns:
        None. Results are written to disk per directory instead of being returned.

    Raises:
        CompareException: if ``npu_path`` or ``bench_path`` is missing.
    """
    npu_path = input_params.get('npu_path')
    bench_path = input_params.get('bench_path')
    if not npu_path or not bench_path:
        # Path(None) would otherwise fail later with an unhelpful TypeError.
        logger.error("Both 'npu_path' and 'bench_path' must be provided for directory compare.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    npu_root = Path(npu_path)
    bench_root = Path(bench_path)
    # An explicit None for map_dict is treated the same as "no name mapping".
    name_map_dict = input_params.get('map_dict') or {}
    file_tree = build_mirror_file_tree(npu_root, bench_root)
    if not file_tree:
        logger.warning(f"No directory with npy files under {npu_root} has a matching directory under {bench_root}.")
        return None

    worker = partial(process_directory_pair, name_map_dict=name_map_dict, output_dir=output_dir)
    with ProcessPoolExecutor() as executor:
        # Consuming the iterator drives the progress bar and re-raises any worker exception.
        for _ in tqdm(executor.map(worker, file_tree.items()),
                      total=len(file_tree),
                      desc="Processing directories"):
            pass
    return None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str):
    """
    Compare the ``.npy`` files of one NPU/bench directory pair and write ``result.csv``.

    Args:
        item: ``(relative path, (npu directory, bench directory))`` tuple, as produced by
            ``build_mirror_file_tree``.
        name_map_dict: mapping from NPU file-name keys to bench file-name keys, consulted
            when a key has no exact match on the bench side.
        output_dir: root output directory; results go to ``output_dir/<relative path>``.

    Returns:
        None. The comparison result is written to ``<output path>/result.csv``.
    """
    rel_path, (npu_dir, bench_dir) = item

    # Validate the mirrored output path before creating it (not afterwards).
    output_path = Path(output_dir) / rel_path
    check_path_before_create(output_path)
    create_directory(output_path)

    # find_npy_files walks recursively, but every sub-directory holding npy files is its
    # own entry of the mirrored file tree. Keep only files located directly in this
    # directory so nested files are not compared twice or mixed across directories.
    npu_files = _keep_direct_children(find_npy_files(npu_dir), npu_dir)
    bench_files = _keep_direct_children(find_npy_files(bench_dir), bench_dir)
    map_dict = generate_map_dict(npu_files, bench_files, name_map_dict)

    if not map_dict:
        logger.warning(f"No file pairs found in {rel_path}")
        return None

    result_df = do_multi_process(process_chunk, map_dict)
    result_path = os.path.join(output_path, 'result.csv')
    write_df_to_csv(result_df, result_path)
    logger.info(f"Results saved to {result_path}")
    return None


def _keep_direct_children(files_dict, directory):
    """Filter a ``find_npy_files`` result down to files located directly inside ``directory``."""
    target = os.path.normpath(str(directory))
    filtered = {}
    for key, paths in files_dict.items():
        kept = [p for p in paths if os.path.normpath(os.path.dirname(p)) == target]
        if kept:
            filtered[key] = kept
    return filtered
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[str, str]]:
    """
    Map each NPU directory that directly contains ``.npy`` files to its bench counterpart.

    Args:
        npu_root: root directory of the NPU data.
        bench_root: root directory of the benchmark data.

    Returns:
        Dict keyed by the directory path relative to ``npu_root``; each value is the
        ``(npu_dir, bench_dir)`` pair of joined path strings. Directories whose bench
        counterpart fails ``check_file_or_directory_path`` are left out.
    """
    file_tree = {}
    # Relative directories whose bench counterpart failed validation. Remembering them
    # means each directory is validated once rather than once per npy file it contains.
    rejected_dirs = set()

    for npu_path in npu_root.rglob('*.npy'):
        dir_path = npu_path.relative_to(npu_root).parent
        if dir_path in file_tree or dir_path in rejected_dirs:
            continue
        npu_dir_pair = os.path.join(npu_root, dir_path)
        bench_dir_pair = os.path.join(bench_root, dir_path)
        try:
            check_file_or_directory_path(bench_dir_pair, isdir=True)
        except FileCheckException:
            rejected_dirs.add(dir_path)
            continue
        file_tree[dir_path] = (npu_dir_pair, bench_dir_pair)

    return file_tree
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def find_npy_files(directory):
    """
    Recursively collect ``.npy`` files under ``directory``, grouped by a name key.

    The key is the file name split on ``'_'`` with its last two fields dropped. Files whose
    name has fewer than two ``'_'``-separated fields are skipped; a name with exactly two
    fields yields the empty-string key.

    Each group is sorted in natural order, meaning digit runs in the file name are compared
    numerically, so ``..._9.npy`` sorts before ``..._10.npy``. ``os.walk`` yields entries
    in arbitrary, filesystem-dependent order, and callers pair NPU and bench files by their
    position in these lists, so a deterministic order is required.

    Args:
        directory: directory to search (walked recursively).

    Returns:
        Dict mapping key -> list of full file paths.
    """
    import re

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit-run parts, so odd
        # positions are always digit runs and compare as integers.
        parts = re.split(r'(\d+)', os.path.basename(path))
        return [int(part) if idx % 2 else part for idx, part in enumerate(parts)], path

    npy_files_dict = {}
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".npy"):
                continue
            fields = file.split('_')
            if len(fields) < 2:
                continue
            key = '_'.join(fields[:-2])
            npy_files_dict.setdefault(key, []).append(os.path.join(root, file))

    for paths in npy_files_dict.values():
        paths.sort(key=_natural_key)
    return npy_files_dict
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
    """
    Pair NPU files with bench files that share the same name key.

    For each NPU key, the bench files with the same key are used. If there are none,
    ``name_map_dict[key]`` is looked up as an alternative bench key. Files are paired
    position by position; surplus files on the longer side stay unpaired.

    Args:
        npu_file_dict: key -> list of NPU file paths.
        bench_file_dict: key -> list of bench file paths.
        name_map_dict: optional NPU key -> bench key mapping.

    Returns:
        Dict mapping ``"<key>_<i>"`` -> ``(npu_file, bench_file)`` for every key
        (not only the last one); empty if nothing matches.
    """
    name_map_dict = name_map_dict or {}
    result_dict = {}
    for key, npu_file_list in npu_file_dict.items():
        bench_file_list = bench_file_dict.get(key)
        if not bench_file_list and key in name_map_dict:
            bench_file_list = bench_file_dict.get(name_map_dict.get(key))
        # Covers both a missing key (None) and an empty list.
        if not bench_file_list:
            continue
        # zip stops at the shorter list, matching the original length guard.
        for i, (npu_file, bench_file) in enumerate(zip(npu_file_list, bench_file_list)):
            result_dict[f"{key}_{i}"] = (npu_file, bench_file)
    return result_dict
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def do_multi_process(func, map_dict):
    """
    Run ``func`` over ``map_dict`` in parallel chunks and gather the resulting rows.

    ``map_dict`` is split into consecutive chunks of equal size. Each chunk is paired with
    the matching row slice of a pre-allocated result DataFrame and processed by
    ``func(df_chunk, start_idx, map_chunk, lock)`` in a ``multiprocessing.Pool`` worker.

    Args:
        func: worker callable returning the filled DataFrame chunk (e.g. ``process_chunk``).
        map_dict: mapping of result name -> ``(npu_file, bench_file)``.

    Returns:
        The concatenated DataFrame of all chunks, or an empty DataFrame if a worker or the
        main process raised.
    """
    result_len = len(map_dict)
    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
    # Number of rows handled by each task.
    df_chunk_size = result_len // process_num

    # Pre-allocate one result row per map_dict entry.
    result_df = initialize_result_df(result_len)
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]
        process_num = 1
    logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")

    # Split with the same chunk size so map_chunks[i] lines up with df_chunks[i].
    map_chunks = split_dict(map_dict, df_chunk_size)

    # Using the Manager as a context manager shuts its server process down as soon as the
    # comparison finishes, instead of leaving that to garbage collection.
    with multiprocessing.Manager() as manager:
        lock = manager.RLock()
        pool = multiprocessing.Pool(process_num)
        progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)

        def update_progress(size, progress_lock, extra_param=None):
            # extra_param receives the task result that apply_async passes to callbacks; unused.
            with progress_lock:
                progress_bar.update(size)

        def err_call(args):
            logger.error('multiprocess compare failed! Reason: {}'.format(args))
            try:
                pool.close()
            except OSError as e:
                logger.error(f'pool terminate failed: {str(e)}')

        try:
            results = []
            for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
                start_idx = df_chunk_size * process_idx
                results.append(pool.apply_async(
                    func,
                    args=(df_chunk, start_idx, map_chunk, lock),
                    error_callback=err_call,
                    callback=partial(update_progress, len(map_chunk), lock)
                ))
            # get() re-raises any exception raised inside a worker.
            final_results = [r.get() for r in results]
            return pd.concat(final_results, ignore_index=True)
        except Exception as e:
            logger.error(f"\nMain process error: {str(e)}")
            pool.terminate()
            return pd.DataFrame({})
        finally:
            # close() is a no-op after terminate(); join() waits for the workers to exit.
            pool.close()
            pool.join()
            progress_bar.close()
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def initialize_result_df(total_size):
    """Pre-allocate an all-empty result DataFrame with ``total_size`` rows and the compare-result columns."""
    identity_columns = [
        CompareConst.NAME,
        CompareConst.NPU_DTYPE,
        CompareConst.BENCH_DTYPE,
        CompareConst.NPU_SHAPE,
        CompareConst.BENCH_SHAPE,
    ]
    metric_columns = [
        CompareConst.COSINE,
        CompareConst.EUC_DIST,
        CompareConst.MAX_ABS_ERR,
        CompareConst.MAX_RELATIVE_ERR,
        CompareConst.ONE_THOUSANDTH_ERR_RATIO,
        CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
    ]
    statistic_columns = [
        CompareConst.NPU_MAX,
        CompareConst.NPU_MIN,
        CompareConst.NPU_MEAN,
        CompareConst.NPU_NORM,
        CompareConst.BENCH_MAX,
        CompareConst.BENCH_MIN,
        CompareConst.BENCH_MEAN,
        CompareConst.BENCH_NORM,
    ]
    summary_columns = [
        CompareConst.ACCURACY,
        CompareConst.ERROR_MESSAGE,
        CompareConst.DATA_NAME,
    ]
    all_columns = identity_columns + metric_columns + statistic_columns + summary_columns
    return pd.DataFrame(index=range(total_size), columns=all_columns)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def split_dict(input_dict, chunk_size):
    """
    Split ``input_dict`` into consecutive sub-dicts holding at most ``chunk_size`` items each.

    A non-positive ``chunk_size`` means "do not split": the original dict object is
    returned as the only element.
    """
    if chunk_size <= 0:
        return [input_dict]
    pairs = list(input_dict.items())
    chunks = []
    for start in range(0, len(pairs), chunk_size):
        chunks.append(dict(pairs[start:start + chunk_size]))
    return chunks
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_tensor_stats(tensor: np.ndarray) -> Tuple[float, float, float, float]:
    """
    Return the ``(max, min, mean, L2 norm)`` statistics of ``tensor``.

    ``np.max``/``np.min`` raise ValueError on an empty array. That would abort the whole
    comparison chunk, so an empty tensor instead yields NaN for all four statistics.
    """
    if np.size(tensor) == 0:
        return np.nan, np.nan, np.nan, np.nan
    t_max = np.max(tensor)
    t_min = np.min(tensor)
    t_mean = np.mean(tensor)
    t_l2norm = np.linalg.norm(tensor)
    return t_max, t_min, t_mean, t_l2norm
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def process_chunk(df, start_idx, map_chunk, lock):
    """
    Compare every file pair of one chunk and write the results into ``df``.

    Args:
        df: slice of the pre-allocated result DataFrame belonging to this chunk.
        start_idx: row label of this chunk's first row in the full result DataFrame.
        map_chunk: mapping of result name -> ``(npu_file, bench_file)``.
        lock: lock passed through to ``_save_part_df``, which holds it while writing rows.

    Returns:
        ``df`` with one row filled per entry of ``map_chunk`` (via ``_save_part_df``).
    """
    results = []
    for name, (npu_file, bench_file) in map_chunk.items():
        n_value = load_npy(npu_file)
        # Cross-framework comparison would need a load_pt-style loader for the bench side.
        b_value = load_npy(bench_file)

        error_flag = False
        err_msg = ""
        err_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
        cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio = err_list
        a_max, a_min, a_mean, a_l2norm = get_tensor_stats(n_value)
        b_max, b_min, b_mean, b_l2norm = get_tensor_stats(b_value)

        results.append(ComparisonResult(
            name=name,
            npu_dtype=n_value.dtype,
            bench_dtype=b_value.dtype,
            npu_shape=n_value.shape,
            bench_shape=b_value.shape,
            cosine=cos_sim,
            euc_dist=euc_dist,
            max_abs_err=max_abs_err,
            max_relative_err=max_relative_err,
            one_thousandth_err_ratio=one_thousand_err_ratio,
            five_thousandth_err_ratio=five_thousand_err_ratio,
            npu_max=a_max,
            npu_min=a_min,
            npu_mean=a_mean,
            npu_norm=a_l2norm,
            bench_max=b_max,
            bench_min=b_min,
            bench_mean=b_mean,
            bench_norm=b_l2norm,
            accuracy=check_accuracy(cos_sim, max_abs_err),
            error_message=err_msg,
            data_name=[npu_file, bench_file]
        ))
    return _save_part_df(df, start_idx, results, lock)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@dataclass
class ComparisonResult:
    """One result row of the directory comparison; each field's trailing comment names its CompareConst column."""
    name: str  # CompareConst.NAME
    npu_dtype: Any  # CompareConst.NPU_DTYPE
    bench_dtype: Any  # CompareConst.BENCH_DTYPE
    npu_shape: Tuple[int, ...]  # CompareConst.NPU_SHAPE
    bench_shape: Tuple[int, ...]  # CompareConst.BENCH_SHAPE
    cosine: float  # CompareConst.COSINE
    euc_dist: float  # CompareConst.EUC_DIST
    max_abs_err: float  # CompareConst.MAX_ABS_ERR
    max_relative_err: float  # CompareConst.MAX_RELATIVE_ERR
    one_thousandth_err_ratio: float  # CompareConst.ONE_THOUSANDTH_ERR_RATIO
    five_thousandth_err_ratio: float  # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO
    npu_max: float  # CompareConst.NPU_MAX
    npu_min: float  # CompareConst.NPU_MIN
    npu_mean: float  # CompareConst.NPU_MEAN
    npu_norm: float  # CompareConst.NPU_NORM
    bench_max: float  # CompareConst.BENCH_MAX
    bench_min: float  # CompareConst.BENCH_MIN
    bench_mean: float  # CompareConst.BENCH_MEAN
    bench_norm: float  # CompareConst.BENCH_NORM
    # CompareConst.ACCURACY; filled from check_accuracy() — NOTE(review): confirm it really returns a bool
    accuracy: bool
    error_message: str  # CompareConst.ERROR_MESSAGE
    data_name: List[str]  # CompareConst.DATA_NAME; holds [npu_file, bench_file]
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _save_part_df(df, start_idx, results, lock):
    """
    Write ``results`` into ``df`` one row per result, starting at row label ``start_idx``.

    The writes happen while ``lock`` is held. Shapes and the data-name list are stored as
    their ``str()`` representation.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError.
    """
    # (column, ComparisonResult attribute, optional converter), in write order.
    column_fields = (
        (CompareConst.NAME, 'name', None),
        (CompareConst.NPU_DTYPE, 'npu_dtype', None),
        (CompareConst.BENCH_DTYPE, 'bench_dtype', None),
        (CompareConst.NPU_SHAPE, 'npu_shape', str),
        (CompareConst.BENCH_SHAPE, 'bench_shape', str),
        (CompareConst.COSINE, 'cosine', None),
        (CompareConst.EUC_DIST, 'euc_dist', None),
        (CompareConst.MAX_ABS_ERR, 'max_abs_err', None),
        (CompareConst.MAX_RELATIVE_ERR, 'max_relative_err', None),
        (CompareConst.ONE_THOUSANDTH_ERR_RATIO, 'one_thousandth_err_ratio', None),
        (CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, 'five_thousandth_err_ratio', None),
        (CompareConst.NPU_MAX, 'npu_max', None),
        (CompareConst.NPU_MIN, 'npu_min', None),
        (CompareConst.NPU_MEAN, 'npu_mean', None),
        (CompareConst.NPU_NORM, 'npu_norm', None),
        (CompareConst.BENCH_MAX, 'bench_max', None),
        (CompareConst.BENCH_MIN, 'bench_min', None),
        (CompareConst.BENCH_MEAN, 'bench_mean', None),
        (CompareConst.BENCH_NORM, 'bench_norm', None),
        (CompareConst.ACCURACY, 'accuracy', None),
        (CompareConst.ERROR_MESSAGE, 'error_message', None),
        (CompareConst.DATA_NAME, 'data_name', str),
    )
    with lock:
        try:
            for offset, res in enumerate(results):
                row = start_idx + offset
                for column, field_name, convert in column_fields:
                    value = getattr(res, field_name)
                    df.loc[row, column] = convert(value) if convert else value
            return df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
@@ -13,41 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
from msprobe.core.common.utils import CompareException
|
|
18
17
|
from msprobe.core.common.file_utils import create_directory
|
|
19
18
|
from msprobe.core.common.exceptions import FileCheckException
|
|
20
19
|
from msprobe.mindspore.common.log import logger
|
|
21
20
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
22
|
-
from msprobe.core.compare.utils import
|
|
21
|
+
from msprobe.core.compare.utils import compare_distributed_inner
|
|
23
22
|
from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """
    Run the distributed compare for MindSpore dump directories.

    Delegates to ``compare_distributed_inner``, passing ``ms_compare`` as the compare
    function; extra keyword arguments are forwarded unchanged.
    """
    per_dump_compare = ms_compare
    compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, per_dump_compare, **kwargs)
|
|
51
27
|
|
|
52
28
|
|
|
53
29
|
def ms_graph_compare(inputs, outputs):
|