mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,382 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import multiprocessing
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Tuple, Optional, Any
20
+ from concurrent.futures import ProcessPoolExecutor
21
+ from functools import partial
22
+ from pathlib import Path
23
+
24
+ import pandas as pd
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.core.common.utils import CompareException
30
+ from msprobe.core.common.exceptions import FileCheckException
31
+ from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
32
+ check_path_before_create, load_npy
33
+ from msprobe.core.common.const import CompareConst, FileCheckConst
34
+ from msprobe.core.compare.npy_compare import compare_ops_apply
35
+ from msprobe.core.compare.multiprocessing_compute import check_accuracy
36
+
37
+
38
def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
    """
    Compare two dump directory trees, mirroring the input layout under ``output_dir``.

    The directory pairs produced by ``build_mirror_file_tree`` are dispatched to
    ``process_directory_pair`` through a ``ProcessPoolExecutor``.

    Args:
        input_params: Dict read with the keys ``npu_path`` and ``bench_path``, plus the
            optional ``map_dict`` (an empty dict is used when the key is absent).
        output_dir: Root directory handed to every worker for its output.

    Returns:
        Always ``None`` in this implementation; nothing is collected from the workers.
    """
    npu_root = Path(input_params.get('npu_path'))
    bench_root = Path(input_params.get('bench_path'))
    name_map_dict = input_params.get('map_dict', {})
    file_tree = build_mirror_file_tree(npu_root, bench_root)

    # One task per (relative path, (npu dir, bench dir)) entry of the tree.
    worker = partial(process_directory_pair, name_map_dict=name_map_dict, output_dir=output_dir)
    with ProcessPoolExecutor() as executor:
        pending = executor.map(worker, file_tree.items())
        # Draining the iterator drives the progress bar and re-raises worker exceptions.
        for _ in tqdm(pending, total=len(file_tree), desc="Processing directories"):
            pass
    return
65
+
66
+
67
def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str):
    """
    Compare the ``.npy`` files of one NPU/bench directory pair and write ``result.csv``.

    Args:
        item: ``(relative path, (npu directory, bench directory))`` tuple.
        name_map_dict: Name mapping forwarded to ``generate_map_dict``.
        output_dir: Output root; the result is written to ``output_dir / relative path``.

    Returns:
        Always ``None``. When no file pair is found a warning is logged and no CSV is written.
    """
    rel_path, dir_pair = item
    npu_dir, bench_dir = dir_pair

    # Mirror the input layout below the output root.
    output_path = Path(output_dir) / rel_path
    create_directory(output_path)

    # Pair up the files of both sides.
    map_dict = generate_map_dict(find_npy_files(npu_dir), find_npy_files(bench_dir), name_map_dict)
    if not map_dict:
        logger.warning(f"No file pairs found in {rel_path}")
        return None

    # Run the comparison, then persist it.
    result_df = do_multi_process(process_chunk, map_dict)
    check_path_before_create(output_path)
    result_path = os.path.join(output_path, 'result.csv')
    write_df_to_csv(result_df, result_path)
    logger.info(f"Results saved to {result_path}")
    return None
101
+
102
+
103
def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[Path, Path]]:
    """
    Build the mapping of directories that exist on both sides.

    Args:
        npu_root: Root directory of the NPU data.
        bench_root: Root directory of the benchmark data.

    Returns:
        Dict keyed by the directory path relative to ``npu_root``; each value is the
        ``(npu directory, bench directory)`` pair built with ``os.path.join``.
        Directories whose bench counterpart fails ``check_file_or_directory_path`` are left out.
    """
    mirrored = {}

    # Every .npy file below the NPU root contributes its parent directory.
    for npy_file in npu_root.rglob('*.npy'):
        rel_dir = npy_file.relative_to(npu_root).parent
        npu_side = os.path.join(npu_root, rel_dir)
        bench_side = os.path.join(bench_root, rel_dir)
        try:
            check_file_or_directory_path(bench_side, isdir=True)
        except FileCheckException:
            continue
        else:
            # Keep the first pair recorded for a directory.
            mirrored.setdefault(rel_dir, (npu_side, bench_side))

    return mirrored
130
+
131
+
132
def find_npy_files(directory):
    """
    Collect the ``.npy`` files below ``directory``, grouped by name prefix.

    The group key is the file name split on ``'_'`` with its last two pieces removed
    and the rest re-joined with ``'_'``. File names that contain no ``'_'`` are skipped.

    Args:
        directory: Directory walked recursively with ``os.walk``.

    Returns:
        Dict mapping each key to the list of full file paths, in walk order.
    """
    grouped = {}
    for current_dir, _, names in os.walk(directory):
        for name in names:
            if not name.endswith(".npy"):
                continue
            pieces = name.split('_')
            if len(pieces) < 2:
                continue
            # Drop the last two pieces of the file name to form the key.
            prefix = '_'.join(pieces[:-2])
            grouped.setdefault(prefix, []).append(os.path.join(current_dir, name))
    return grouped
149
+
150
+
151
def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
    """
    Pair NPU files with bench files key by key.

    For every key of ``npu_file_dict`` the bench list is looked up under the same key;
    when that yields nothing and the key appears in ``name_map_dict``, the mapped bench
    key is tried instead. Files are paired by position, and surplus files on either
    side are ignored.

    Args:
        npu_file_dict: Dict mapping a key to a list of NPU file paths.
        bench_file_dict: Dict mapping a key to a list of bench file paths.
        name_map_dict: Optional dict mapping an NPU key to a bench key; ``None`` is
            treated as an empty mapping.

    Returns:
        Dict mapping ``"<key>_<index>"`` to ``(npu_file, bench_file)`` for all keys that
        could be matched; empty when nothing matches.
    """
    name_map_dict = name_map_dict or {}
    # Accumulate across all keys; creating the dict per key would keep only the last one.
    result_dict = {}
    for k, npu_file_list in npu_file_dict.items():
        bench_file_list = bench_file_dict.get(k)
        if not bench_file_list and k in name_map_dict:
            bench_file_list = bench_file_dict.get(name_map_dict.get(k))
        # Covers both a missing key (None) and an empty list.
        if not bench_file_list:
            continue
        # zip stops at the shorter list, so unmatched trailing files are dropped.
        for i, (npu_file, bench_file) in enumerate(zip(npu_file_list, bench_file_list)):
            result_dict[f"{k}_{i}"] = (npu_file, bench_file)
    return result_dict
166
+
167
+
168
def do_multi_process(func, map_dict):
    """
    Run ``func`` over ``map_dict`` in a process pool and concatenate the partial results.

    A result DataFrame with one row per entry of ``map_dict`` is pre-allocated and cut
    into chunks; ``map_dict`` is cut with the same chunk size, and each pair of chunks
    becomes one task.

    Args:
        func: Callable invoked as ``func(df_chunk, start_idx, map_chunk, lock)``; its
            return values are concatenated with ``pd.concat``.
        map_dict: Dict of items to compare.

    Returns:
        The concatenated DataFrame, or an empty DataFrame when any exception is raised
        while submitting or collecting the tasks.
    """
    lock = multiprocessing.Manager().RLock()
    result_len = len(map_dict)
    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
    # Number of rows handled by one task.
    df_chunk_size = result_len // process_num

    # Pre-allocate a frame with as many rows as map_dict has entries.
    result_df = initialize_result_df(result_len)
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # Fewer items than processes: handle everything in a single task.
        df_chunks = [result_df]
        process_num = 1
    logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")

    # Split the mapping with the same chunk size so both chunk lists line up.
    map_chunks = split_dict(map_dict, df_chunk_size)

    pool = multiprocessing.Pool(process_num)
    progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)

    def update_progress(size, progress_lock, extra_param=None):
        # extra_param receives the task result passed by apply_async's callback.
        with progress_lock:
            progress_bar.update(size)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            # NOTE(review): the message below says "terminate" but close() is what is called.
            pool.close()
        except OSError as e:
            logger.error(f'pool terminate failed: {str(e)}')

    results = []
    try:
        # Submit one task per chunk pair.
        for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
            start_idx = df_chunk_size * process_idx
            result = pool.apply_async(
                func,
                args=(df_chunk, start_idx, map_chunk, lock),
                error_callback=err_call,
                callback=partial(update_progress, len(map_chunk), lock)
            )
            results.append(result)

        # get() blocks until each task is done and re-raises its exception, if any.
        final_results = [r.get() for r in results]
        pool.close()
        pool.join()
        return pd.concat(final_results, ignore_index=True)
    except Exception as e:
        logger.error(f"\nMain process error: {str(e)}")
        pool.terminate()
        return pd.DataFrame({})
    finally:
        pool.close()
        # Release the progress bar so later console output is not mixed into it.
        progress_bar.close()
227
+
228
+
229
def initialize_result_df(total_size):
    """
    Pre-allocate the result DataFrame.

    Args:
        total_size: Number of rows; the index is ``range(total_size)``.

    Returns:
        A DataFrame without data whose columns are the ``CompareConst`` names listed below.
    """
    return pd.DataFrame(
        index=range(total_size),
        columns=[
            # identification and tensor layout
            CompareConst.NAME,
            CompareConst.NPU_DTYPE,
            CompareConst.BENCH_DTYPE,
            CompareConst.NPU_SHAPE,
            CompareConst.BENCH_SHAPE,
            # comparison metrics
            CompareConst.COSINE,
            CompareConst.EUC_DIST,
            CompareConst.MAX_ABS_ERR,
            CompareConst.MAX_RELATIVE_ERR,
            CompareConst.ONE_THOUSANDTH_ERR_RATIO,
            CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
            # NPU statistics
            CompareConst.NPU_MAX,
            CompareConst.NPU_MIN,
            CompareConst.NPU_MEAN,
            CompareConst.NPU_NORM,
            # bench statistics
            CompareConst.BENCH_MAX,
            CompareConst.BENCH_MIN,
            CompareConst.BENCH_MEAN,
            CompareConst.BENCH_NORM,
            # verdict and bookkeeping
            CompareConst.ACCURACY,
            CompareConst.ERROR_MESSAGE,
            CompareConst.DATA_NAME,
        ],
    )
256
+
257
+
258
def split_dict(input_dict, chunk_size):
    """
    Split a dict into consecutive sub-dicts of at most ``chunk_size`` items.

    Args:
        input_dict: Dict to split; item order is preserved.
        chunk_size: Maximum number of items per chunk.

    Returns:
        List of new dicts when ``chunk_size`` is positive; otherwise a one-element list
        holding ``input_dict`` itself.
    """
    pairs = list(input_dict.items())
    if chunk_size <= 0:
        return [input_dict]
    chunks = []
    for start in range(0, len(pairs), chunk_size):
        chunks.append(dict(pairs[start:start + chunk_size]))
    return chunks
264
+
265
+
266
def get_tensor_stats(tensor: np.ndarray) -> Tuple[float, float, float, float]:
    """
    Compute summary statistics of a tensor.

    Args:
        tensor: Array passed to ``np.max``, ``np.min``, ``np.mean`` and ``np.linalg.norm``.

    Returns:
        ``(max, min, mean, norm)`` where ``norm`` is ``np.linalg.norm`` with default arguments.
    """
    return (
        np.max(tensor),
        np.min(tensor),
        np.mean(tensor),
        np.linalg.norm(tensor),
    )
273
+
274
+
275
def process_chunk(df, start_idx, map_chunk, lock):
    """
    Compare every file pair of one chunk and store the rows into ``df``.

    Args:
        df: DataFrame chunk that receives the rows.
        start_idx: Index label of the first row belonging to this chunk.
        map_chunk: Dict mapping a name to ``(npu_file, bench_file)``.
        lock: Lock forwarded to ``_save_part_df``.

    Returns:
        Whatever ``_save_part_df`` returns for the collected rows.
    """
    rows = []
    for name, (npu_file, bench_file) in map_chunk.items():
        n_value = load_npy(npu_file)
        # Cross-framework support would require a load_pt branch for the bench value.
        b_value = load_npy(bench_file)

        # Starts from "no error" and an empty message.
        metrics, err_msg = compare_ops_apply(n_value, b_value, False, "")
        cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio = metrics
        npu_max, npu_min, npu_mean, npu_norm = get_tensor_stats(n_value)
        bench_max, bench_min, bench_mean, bench_norm = get_tensor_stats(b_value)

        rows.append(ComparisonResult(
            name=name,
            npu_dtype=n_value.dtype,
            bench_dtype=b_value.dtype,
            npu_shape=n_value.shape,
            bench_shape=b_value.shape,
            cosine=cos_sim,
            euc_dist=euc_dist,
            max_abs_err=max_abs_err,
            max_relative_err=max_relative_err,
            one_thousandth_err_ratio=one_thousand_err_ratio,
            five_thousandth_err_ratio=five_thousand_err_ratio,
            npu_max=npu_max,
            npu_min=npu_min,
            npu_mean=npu_mean,
            npu_norm=npu_norm,
            bench_max=bench_max,
            bench_min=bench_min,
            bench_mean=bench_mean,
            bench_norm=bench_norm,
            accuracy=check_accuracy(cos_sim, max_abs_err),
            error_message=err_msg,
            data_name=[npu_file, bench_file],
        ))
    return _save_part_df(df, start_idx, rows, lock)
319
+
320
+
321
@dataclass
class ComparisonResult:
    """One comparison row; the comment on each field names the CompareConst column it fills."""

    # Identification and tensor layout.
    name: str  # CompareConst.NAME
    npu_dtype: Any  # CompareConst.NPU_DTYPE
    bench_dtype: Any  # CompareConst.BENCH_DTYPE
    npu_shape: Tuple[int, ...]  # CompareConst.NPU_SHAPE
    bench_shape: Tuple[int, ...]  # CompareConst.BENCH_SHAPE

    # Comparison metrics.
    cosine: float  # CompareConst.COSINE
    euc_dist: float  # CompareConst.EUC_DIST
    max_abs_err: float  # CompareConst.MAX_ABS_ERR
    max_relative_err: float  # CompareConst.MAX_RELATIVE_ERR
    one_thousandth_err_ratio: float  # CompareConst.ONE_THOUSANDTH_ERR_RATIO
    five_thousandth_err_ratio: float  # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO

    # NPU-side statistics.
    npu_max: float  # CompareConst.NPU_MAX
    npu_min: float  # CompareConst.NPU_MIN
    npu_mean: float  # CompareConst.NPU_MEAN
    npu_norm: float  # CompareConst.NPU_NORM

    # Bench-side statistics.
    bench_max: float  # CompareConst.BENCH_MAX
    bench_min: float  # CompareConst.BENCH_MIN
    bench_mean: float  # CompareConst.BENCH_MEAN
    bench_norm: float  # CompareConst.BENCH_NORM

    # Verdict and bookkeeping.
    accuracy: bool  # CompareConst.ACCURACY
    error_message: str  # CompareConst.ERROR_MESSAGE
    data_name: List[str]  # CompareConst.DATA_NAME
345
+
346
+
347
def _save_part_df(df, start_idx, results, lock):
    """
    Write a list of ``ComparisonResult`` objects into ``df`` while holding ``lock``.

    Args:
        df: DataFrame that receives the values.
        start_idx: Index label of the row for ``results[0]``; later results follow consecutively.
        results: Iterable of ``ComparisonResult``.
        lock: Lock held for the whole write.

    Returns:
        ``df`` after all rows have been written.

    Raises:
        CompareException: ``INVALID_DATA_ERROR`` on ``ValueError``,
            ``INDEX_OUT_OF_BOUNDS_ERROR`` on ``IndexError``.
    """
    with lock:
        try:
            for offset, result in enumerate(results):
                row = offset + start_idx
                cells = (
                    (CompareConst.NAME, result.name),
                    (CompareConst.NPU_DTYPE, result.npu_dtype),
                    (CompareConst.BENCH_DTYPE, result.bench_dtype),
                    # Shapes are tuples; they are stored as strings.
                    (CompareConst.NPU_SHAPE, str(result.npu_shape)),
                    (CompareConst.BENCH_SHAPE, str(result.bench_shape)),
                    (CompareConst.COSINE, result.cosine),
                    (CompareConst.EUC_DIST, result.euc_dist),
                    (CompareConst.MAX_ABS_ERR, result.max_abs_err),
                    (CompareConst.MAX_RELATIVE_ERR, result.max_relative_err),
                    (CompareConst.ONE_THOUSANDTH_ERR_RATIO, result.one_thousandth_err_ratio),
                    (CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, result.five_thousandth_err_ratio),
                    (CompareConst.NPU_MAX, result.npu_max),
                    (CompareConst.NPU_MIN, result.npu_min),
                    (CompareConst.NPU_MEAN, result.npu_mean),
                    (CompareConst.NPU_NORM, result.npu_norm),
                    (CompareConst.BENCH_MAX, result.bench_max),
                    (CompareConst.BENCH_MIN, result.bench_min),
                    (CompareConst.BENCH_MEAN, result.bench_mean),
                    (CompareConst.BENCH_NORM, result.bench_norm),
                    (CompareConst.ACCURACY, result.accuracy),
                    (CompareConst.ERROR_MESSAGE, result.error_message),
                    # The file-name list is stored as a string as well.
                    (CompareConst.DATA_NAME, str(result.data_name)),
                )
                for column, value in cells:
                    df.loc[row, column] = value
            return df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
@@ -13,41 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
16
  from msprobe.core.common.utils import CompareException
18
17
  from msprobe.core.common.file_utils import create_directory
19
18
  from msprobe.core.common.exceptions import FileCheckException
20
19
  from msprobe.mindspore.common.log import logger
21
20
  from msprobe.mindspore.compare.ms_compare import ms_compare
22
- from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
21
+ from msprobe.core.compare.utils import compare_distributed_inner
23
22
  from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
24
23
 
25
24
 
26
25
  def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
27
- if kwargs.get('suffix'):
28
- logger.error("Argument 'suffix' is not supported for compare_distributed.")
29
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
30
- is_print_compare_log = kwargs.get('is_print_compare_log', True)
31
- # get the ranks and match by order
32
- npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
33
- bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
34
- if len(npu_ranks) != len(bench_ranks):
35
- logger.error('The number of ranks in the two runs are different. '
36
- 'Unable to match the ranks. Please use another folder to compare '
37
- 'or use compare() api and manually match the ranks.')
38
- raise CompareException(CompareException.INVALID_PATH_ERROR)
39
- for nr, br in zip(npu_ranks, bench_ranks):
40
- npu_data_dir = os.path.join(npu_dump_dir, nr)
41
- bench_data_dir = os.path.join(bench_dump_dir, br)
42
- npu_path = extract_json(npu_data_dir, stack_json=False)
43
- bench_path = extract_json(bench_data_dir, stack_json=False)
44
-
45
- dump_result_param = {
46
- 'npu_json_path': npu_path,
47
- 'bench_json_path': bench_path,
48
- 'is_print_compare_log': is_print_compare_log
49
- }
50
- ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
26
+ compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, ms_compare, **kwargs)
51
27
 
52
28
 
53
29
  def ms_graph_compare(inputs, outputs):