mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,367 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import random
18
+ from functools import wraps
19
+ from typing import Callable, List, Dict, Tuple, Optional
20
+ import inspect
21
+ import os
22
+ import json
23
+ from collections import defaultdict
24
+ import difflib
25
+
26
+ import numpy as np
27
+ import pandas as pd
28
+ from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
29
+ from msprobe.core.common.file_utils import create_file_in_zip, load_json
30
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
31
+ from msprobe.core.config_check.utils.utils import config_checking_print
32
+ from msprobe.core.common.framework_adapter import FmkAdp
33
+ from msprobe.core.common.const import Const
34
+ from msprobe.core.common.log import logger
35
+
36
+
37
+ # 数据结构:{随机操作名字: [{count: 调用次数, stack: 调用栈列表}]}
38
+ random_op_stats = defaultdict(list)
39
+
40
+
41
def get_call_stack(frame) -> List[str]:
    """Describe the call stack that leads to ``frame``.

    The frame that is passed in is skipped; the walk starts at its caller and
    continues up to the outermost frame.

    Args:
        frame: A frame object, typically ``inspect.currentframe()``. ``None`` is
            accepted and yields an empty stack, because ``inspect.currentframe()``
            may return ``None`` on interpreters without stack frame support.

    Returns:
        One string per frame, ordered from the outermost caller to the innermost
        one, each formatted as
        ``"File {abs_path}, line {lineno}, in {function}, {code}"``.
    """
    stack = []
    # Skip the frame we were handed; only its callers are of interest.
    current_frame = frame.f_back if frame is not None else None

    while current_frame is not None:
        frame_info = inspect.getframeinfo(current_frame)
        filename = os.path.abspath(frame_info.filename)
        # code_context is None when the source text is unavailable.
        code_line = frame_info.code_context[0].strip() if frame_info.code_context else ""

        stack.append(f"File {filename}, line {frame_info.lineno}, in {frame_info.function}, {code_line}")
        current_frame = current_frame.f_back

    # Frames were gathered innermost-first; flip them to read outermost-first.
    return stack[::-1]
59
+
60
+
61
def track_random_call(func: Callable, name: str):
    """Wrap ``func`` so that each call is counted in ``random_op_stats``.

    Calls are grouped by call stack: ``random_op_stats[name]`` holds a list of
    ``{'count': int, 'stack': List[str]}`` entries, one per distinct stack.

    Args:
        func: The callable to wrap.
        name: The key under which calls are recorded.

    Returns:
        A wrapper with the metadata of ``func`` that records the call and then
        returns whatever ``func`` returns. Exceptions raised by ``func``
        propagate unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        frame = inspect.currentframe()
        try:
            stack = get_call_stack(frame)
        finally:
            # Drop the frame reference as soon as possible (also on failure)
            # so that no reference cycle keeps the frame alive.
            del frame

        # Bump the counter of an identical stack, or start a new entry.
        entries = random_op_stats[name]
        for entry in entries:
            if entry['stack'] == stack:
                entry['count'] += 1
                break
        else:
            entries.append({'count': 1, 'stack': stack})

        return func(*args, **kwargs)

    return wrapper
87
+
88
+
89
def load_stats_files(directory: str) -> Dict[int, Dict[str, List[Dict]]]:
    """Load every ``rank<N>*.json`` statistics file found directly in ``directory``.

    Files whose name does not start with ``rank`` and end with ``.json`` are
    ignored.

    Args:
        directory: Directory that contains the per-rank statistics files.

    Returns:
        A mapping from the integer rank id to the content of that rank's file.

    Raises:
        ValueError: If a matching file name carries no numeric rank id between
            ``rank`` and the first ``.``.
    """
    rank_data = {}
    for file in os.listdir(directory):
        if not (file.startswith('rank') and file.endswith('.json')):
            continue

        # "rank12.json" -> "rank12" -> "12"
        rank = file.split('.')[0][4:]
        if not rank.isdigit():
            message = f"extract rank id from {file} failed"
            logger.error(message)
            raise ValueError(message)

        rank_data[int(rank)] = load_json(os.path.join(directory, file))

    return rank_data
105
+
106
+
107
def stack_match(stack1: List[str], stack2: List[str], threshold: float = 0.8) -> bool:
    """Decide whether two call stacks are similar, frame by frame.

    Each pair of frames gets a score that is the mean of three parts: the file
    path score, the function name score (1.0 when equal, else 0.0) and the
    similarity ratio of the two code lines.

    Args:
        stack1: The first call stack.
        stack2: The second call stack.
        threshold: Minimum score every frame pair has to reach, 0.8 by default.

    Returns:
        ``False`` when the stacks differ in length or any frame pair scores
        below ``threshold``; ``True`` otherwise.
    """
    if len(stack1) != len(stack2):
        return False

    def _frame_score(left: str, right: str) -> float:
        left_path, left_func, left_code = _parse_frame(left)
        right_path, right_func, right_code = _parse_frame(right)

        path_part = _compare_path(left_path, right_path)
        func_part = 1.0 if left_func == right_func else 0.0
        code_part = difflib.SequenceMatcher(None, left_code, right_code).ratio()
        # Path, function name and code line each carry one third of the weight.
        return (path_part + func_part + code_part) / 3.0

    return not any(_frame_score(left, right) < threshold for left, right in zip(stack1, stack2))
138
+
139
+
140
+ def _parse_frame(frame: str) -> Tuple[str, str, str]:
141
+ """
142
+ 解析栈帧字符串,提取路径、函数名和代码行
143
+
144
+ 参数:
145
+ - frame: 栈帧字符串。格式为"File {path}, line {line}, in {func}, {code}"
146
+
147
+ 返回:
148
+ - path, func, code
149
+ """
150
+ path = func = code = ''
151
+ stack_info = frame.split(' ')
152
+ if len(stack_info) > 6:
153
+ path = stack_info[1][:-1]
154
+ func = stack_info[5][:-1]
155
+ code = ' '.join(stack_info[6:])
156
+ return path, func, code
157
+
158
+
159
+ def _compare_path(path1: str, path2: str) -> float:
160
+ """比较两个路径的相似度,只考虑文件名"""
161
+ if not path1 or not path2:
162
+ return 0.0
163
+
164
+ # 提取文件名(忽略目录路径)
165
+ file1 = os.path.basename(path1)
166
+ file2 = os.path.basename(path2)
167
+
168
+ return 1.0 if file1 == file2 else 0.0
169
+
170
+
171
def find_matching_stack(bench_stack: List[str], cmp_stacks: List[Dict]) -> Optional[Dict]:
    """Return the first compare-side entry whose stack matches ``bench_stack``.

    Args:
        bench_stack: The call stack recorded on the benchmark side.
        cmp_stacks: Compare-side entries, each shaped like
            ``{'count': int, 'stack': List[str]}``.

    Returns:
        The first matching entry, or ``None`` when nothing matches.
    """
    return next(
        (candidate for candidate in cmp_stacks if stack_match(candidate['stack'], bench_stack)),
        None,
    )
187
+
188
+
189
def stack_list_to_string(stack_list):
    """Join a call stack list into one newline-separated string.

    Anything that is not a list (for example the "no match stack" marker) is
    handed back untouched.
    """
    if not isinstance(stack_list, list):
        return stack_list
    return '\n'.join(stack_list)
197
+
198
+
199
def _compare_op_entries(op: str, rank: int, bench_stacks: List[Dict], cmp_stacks: List[Dict]) -> List[list]:
    """Build the result rows for one random operation on one rank."""
    rows = []

    # Walk the benchmark side and pair each stack with a compare-side stack.
    for bench_entry in bench_stacks:
        bench_stack = bench_entry['stack']
        bench_count = bench_entry['count']
        cmp_entry = find_matching_stack(bench_stack, cmp_stacks)

        if cmp_entry:
            cmp_count = cmp_entry['count']
            rows.append([op, rank, bench_stack, cmp_entry['stack'], bench_count, cmp_count,
                         bench_count == cmp_count])
        else:
            rows.append([op, rank, bench_stack, "no match stack", bench_count, 0, False])

    # Report compare-side stacks that no benchmark stack matched.
    for cmp_entry in cmp_stacks:
        cmp_stack = cmp_entry['stack']
        if not any(stack_match(bench_entry['stack'], cmp_stack) for bench_entry in bench_stacks):
            rows.append([op, rank, "no match stack", cmp_stack, 0, cmp_entry['count'], False])

    return rows


def compare_random_calls(bench_dir: str = 'bench', cmp_dir: str = 'cmp') -> pd.DataFrame:
    """Compare the random call statistics stored in two directories.

    Args:
        bench_dir: Directory with the benchmark per-rank statistics files.
        cmp_dir: Directory with the per-rank statistics files to compare.

    Returns:
        A DataFrame with the columns of ``RandomChecker.result_header``. Rows are
        ordered by rank and then by operation name, so repeated runs on the same
        input produce the same report. Stacks are rendered as newline-separated
        strings; a side without a matching stack shows "no match stack".
    """
    bench_rank_data = load_stats_files(bench_dir)
    cmp_rank_data = load_stats_files(cmp_dir)

    results = []
    for rank in sorted(set(bench_rank_data.keys()) | set(cmp_rank_data.keys())):
        bench_data = bench_rank_data.get(rank, {})
        cmp_data = cmp_rank_data.get(rank, {})

        # Sort the operation names: iterating the raw set would make the row
        # order differ from run to run.
        for op in sorted(set(bench_data.keys()) | set(cmp_data.keys())):
            results.extend(_compare_op_entries(op, rank, bench_data.get(op, []), cmp_data.get(op, [])))

    df = pd.DataFrame(results, columns=RandomChecker.result_header)

    # Render the stack lists as readable multi-line strings.
    df['bench_stack'] = df['bench_stack'].apply(stack_list_to_string)
    df['cmp_stack'] = df['cmp_stack'].apply(stack_list_to_string)

    return df
251
+
252
+
253
def torch_patchs():
    """Replace Torch random functions with call-tracking wrappers.

    Covers a fixed set of ``torch`` module functions and ``torch.Tensor``
    in-place sampling methods; each is recorded under its qualified name.
    """
    import torch

    module_names = ('rand', 'randint', 'randn', 'rand_like', 'randint_like', 'randn_like', 'manual_seed')
    tensor_names = ('exponential_', 'geometric_', 'log_normal_', 'cauchy_')

    # Resolve every original first, then install the wrappers.
    module_originals = [(attr, getattr(torch, attr)) for attr in module_names]
    for attr, original in module_originals:
        setattr(torch, attr, track_random_call(original, f"torch.{attr}"))

    tensor_originals = [(attr, getattr(torch.Tensor, attr)) for attr in tensor_names]
    for attr, original in tensor_originals:
        setattr(torch.Tensor, attr, track_random_call(original, f"torch.Tensor.{attr}"))
276
+
277
+
278
def mindspore_patchs():
    """Replace MindSpore random functions with call-tracking wrappers.

    Every function is wrapped under its own attribute name, so the patched
    attribute keeps calling the function it called before. Binding a wrapper
    of one function to the name of another (for example ``ops.uniform`` to
    ``ops.rand``) would change what user code executes.

    Names that the installed MindSpore version does not provide are skipped.
    """
    import mindspore

    for name in ('rand', 'randint', 'randn', 'uniform', 'normal'):
        func = getattr(mindspore.ops, name, None)
        if callable(func):
            setattr(mindspore.ops, name, track_random_call(func, f"mindspore.ops.{name}"))

    for name in ('set_seed', 'manual_seed'):
        func = getattr(mindspore, name, None)
        if callable(func):
            setattr(mindspore, name, track_random_call(func, f"mindspore.{name}"))
295
+
296
+
297
@register_checker_item("random")
class RandomChecker(BaseChecker):
    """Checker that records random function calls and compares them between runs."""

    input_needed = None
    target_name_in_zip = "random"
    result_header = ['op', 'rank', 'bench_stack', 'cmp_stack', 'bench_count', 'cmp_count', 'check_result']
    # Flipped to True after the statistics have been written, so that the
    # pre-forward hook writes them only once.
    write_once = False

    @staticmethod
    def pack(pack_input):
        """Register a pre-forward hook that writes the random call statistics into the zip.

        Args:
            pack_input: Object whose ``output_zip_path`` attribute names the target zip.
        """
        output_zip_path = pack_input.output_zip_path

        def collect_input(model, args, kwargs, step):
            if RandomChecker.write_once:
                return

            stats_filepath = os.path.join(RandomChecker.target_name_in_zip, f"rank{FmkAdp.get_rank_id()}.json")

            # JSON layout: {op name: [{count: calls, stack: call stack list}]}
            stats_json = dict(random_op_stats)

            create_file_in_zip(output_zip_path, stats_filepath, json.dumps(stats_json, indent=4))
            config_checking_print(f"已将随机调用统计打包到: {stats_filepath}")
            RandomChecker.write_once = True

        register_pre_forward_fun_list(collect_input)

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare the random call statistics of two unpacked result directories.

        Returns:
            ``(target_name_in_zip, pass_check, df)`` where ``pass_check`` is True
            when no row of ``df`` has ``check_result`` equal to False.
        """
        bench_stats_path = os.path.join(bench_dir, RandomChecker.target_name_in_zip)
        cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip)

        df = compare_random_calls(bench_stats_path, cmp_stats_path)
        pass_check = False not in df['check_result'].values

        return RandomChecker.target_name_in_zip, pass_check, df

    @staticmethod
    def apply_patches(fmk=Const.PT_FRAMEWORK):
        """Install call-tracking wrappers on Python, NumPy and framework random functions.

        Args:
            fmk: Framework whose random functions are patched in addition to
                ``random`` and ``numpy.random``.

        Raises:
            Exception: If ``fmk`` is not a supported framework. The check runs
                before anything is patched, so a failed call leaves ``random``
                and ``numpy.random`` untouched.
        """
        # Validate the framework first so that a bad argument has no side effects.
        if fmk == Const.PT_FRAMEWORK:
            patch_framework = torch_patchs
        elif fmk == Const.MS_FRAMEWORK:
            patch_framework = mindspore_patchs
        else:
            raise Exception(f"不支持的框架: {fmk}, 支持的框架: {FmkAdp.supported_fmk}")

        # Python's random module
        random_patches = {
            'random': random.random,
            'randint': random.randint,
            'uniform': random.uniform,
            'choice': random.choice
        }
        for name, func in random_patches.items():
            setattr(random, name, track_random_call(func, f"random.{name}"))

        # NumPy random functions
        np_random_patches = {
            'rand': np.random.rand,
            'randint': np.random.randint,
            'choice': np.random.choice,
            'normal': np.random.normal
        }
        for name, func in np_random_patches.items():
            setattr(np.random, name, track_random_call(func, f"np.random.{name}"))

        # Framework-specific random functions
        patch_framework()
@@ -0,0 +1,147 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import json
18
+ import pandas as pd
19
+
20
+ from msprobe.core.common.file_utils import create_file_in_zip, os_walk_for_files, load_json
21
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
22
+ from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
23
+ from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
24
+ from msprobe.core.common.framework_adapter import FmkAdp
25
+
26
+
27
def collect_weights_data(model):
    """Map each named parameter of ``model`` to its tensor features.

    Parameters whose dtype is not float32 are converted with ``.float()``
    before the features are computed.
    """
    def _as_float32(tensor):
        return tensor.float() if tensor.dtype != FmkAdp.dtype("float32") else tensor

    return {
        param_name: get_tensor_features(_as_float32(tensor))
        for param_name, tensor in FmkAdp.named_parameters(model)
    }
34
+
35
+
36
def compare_weight_file(bench_file, cmp_file):
    """Compare the weight statistics stored in two JSON files.

    Args:
        bench_file: Path of the benchmark statistics file.
        cmp_file: Path of the statistics file to compare.

    Returns:
        A list with one dict per weight name found in either file, ordered by
        weight name. ``equal`` is True/False when both files contain the
        weight, or "only bench have" / "only cmp have" when one side lacks it.
        Each ``*_relative_diff`` is ``abs(bench - cmp) / abs(bench)``, or None
        when the benchmark value is 0 or the weight exists on one side only.
    """
    bench_data = load_json(bench_file)
    cmp_data = load_json(cmp_file)
    keys = ["max", "min", "mean", "norm"]

    results = []
    # Sorted so that the row order is the same on every run.
    for weight_name in sorted(set(bench_data.keys()) | set(cmp_data.keys())):
        result = {
            "weight_name": weight_name,
            "equal": None,
            "max_relative_diff": None,
            "min_relative_diff": None,
            "mean_relative_diff": None,
            "norm_relative_diff": None
        }

        if weight_name not in bench_data:
            result["equal"] = "only cmp have"
            results.append(result)
            continue

        if weight_name not in cmp_data:
            result["equal"] = "only bench have"
            results.append(result)
            continue

        bench_vals = bench_data[weight_name]
        cmp_vals = cmp_data[weight_name]
        result["equal"] = all(bench_vals[k] == cmp_vals[k] for k in keys)

        for key in keys:
            bench_val = bench_vals[key]
            # Divide by the magnitude: a negative benchmark value (common for
            # "min" and "mean") must not produce a negative relative diff.
            result[f"{key}_relative_diff"] = (
                abs(bench_val - cmp_vals[key]) / abs(bench_val) if bench_val != 0 else None
            )

        results.append(result)

    return results
75
+
76
+
77
def compare_weight(bench_dir, cmp_dir):
    """Compare every weight JSON file under ``bench_dir`` with its counterpart in ``cmp_dir``.

    The directory of each file, relative to ``bench_dir``, must start with
    ``step<N>/rank<M>``; these give the step and rank of the result rows.

    Returns:
        A DataFrame with the columns of ``WeightsChecker.result_header``, sorted
        by step and then rank.

    Raises:
        Exception: If a JSON file sits fewer than two directory levels below
            ``bench_dir``.
    """
    rows = []
    for file_info in os_walk_for_files(bench_dir, 10):
        file_name = file_info["file"]
        if not file_name.endswith('.json'):
            continue

        root = file_info["root"]
        bench_file = os.path.join(root, file_name)
        relative_path = os.path.relpath(root, bench_dir)
        cmp_file = os.path.join(cmp_dir, relative_path, file_name)

        path_parts = relative_path.split(os.sep)
        if len(path_parts) < 2:
            raise Exception("Can not compare weights because the extracted file has been corrupted!")
        step = int(path_parts[0].replace("step", ""))
        rank = int(path_parts[1].replace("rank", ""))

        if os.path.exists(cmp_file):
            for row in compare_weight_file(bench_file, cmp_file):
                row["step"] = step
                row["rank"] = rank
                rows.append(row)
            continue

        # No counterpart on the compare side: list every benchmark weight as bench-only.
        for weight_name in load_json(bench_file):
            rows.append({
                "step": step,
                "rank": rank,
                "weight_name": weight_name,
                "equal": "only bench have",
                "max_relative_diff": None,
                "min_relative_diff": None,
                "mean_relative_diff": None,
                "norm_relative_diff": None
            })

    frame = pd.DataFrame(rows, columns=WeightsChecker.result_header)
    return frame.sort_values(by=['step', 'rank'], ascending=[True, True])
118
+
119
+
120
@register_checker_item("weights")
class WeightsChecker(BaseChecker):
    """Checker that stores per-step, per-rank weight statistics and compares them."""

    input_needed = "model"
    multi_rank = True

    target_name_in_zip = "weights"
    result_header = ["step", "rank", "weight_name", "equal", "max_relative_diff",
                     "min_relative_diff", "mean_relative_diff", "norm_relative_diff"]

    @staticmethod
    def pack(pack_input):
        """Register a pre-forward hook that writes the model's weight statistics into the zip."""
        zip_path = pack_input.output_zip_path

        def collect_weights(model, args, kwargs, step):
            features = collect_weights_data(model)
            # Stored as weights/step<N>/rank<M>/weight.json inside the zip.
            target = os.path.join(
                WeightsChecker.target_name_in_zip,
                f"step{step}",
                f"rank{FmkAdp.get_rank_id()}",
                "weight.json",
            )
            create_file_in_zip(zip_path, target, json.dumps(features, indent=4))
            config_checking_print(f"add weights info to zip")

        register_pre_forward_fun_list(collect_weights)

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare the weight statistics of two unpacked result directories.

        Returns:
            ``(target_name_in_zip, pass_check, df)`` where ``pass_check`` is True
            when no row of ``df`` has ``equal`` equal to False.
        """
        name = WeightsChecker.target_name_in_zip
        df = compare_weight(os.path.join(bench_dir, name), os.path.join(cmp_dir, name))
        pass_check = False not in df['equal'].values
        return name, pass_check, df
@@ -0,0 +1,74 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict
17
+ from tqdm import tqdm
18
+
19
+ from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists
20
+ from msprobe.core.common.log import logger
21
+ from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights
22
+ from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC
23
+
24
+
25
+
26
def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
    """Compare the weights shared by two checkpoints and save the metrics as JSON.

    Every metric registered in ``METRIC_FUNC`` is evaluated for each parameter
    name present in both checkpoints. Parameters found in only one checkpoint
    are reported through a warning and left out of the result.

    Args:
        ckpt_path1 (str): Path to first checkpoint directory
        ckpt_path2 (str): Path to second checkpoint directory
        output_path (str): Path to save comparison results JSON file

    Returns:
        Dict: Mapping of parameter name to its metric results:
            {
                "param_name": {
                    "<metric name>": <value returned by the metric function,
                                      or the string 'error' if it raised>,
                    ...
                },
                ...
            }
    """
    # Validate the output location before loading anything.
    # NOTE(review): presumably these helpers raise on an invalid or already existing path — confirm.
    check_path_before_create(output_path)
    check_path_not_exists(output_path)

    # Load both checkpoints
    weights1 = load_megatron_weights(ckpt_path1)
    weights2 = load_megatron_weights(ckpt_path2)

    names1 = set(weights1)
    names2 = set(weights2)

    # Warn only when something is actually missing, instead of logging an empty set.
    only_in_ckpt1 = names1 - names2
    only_in_ckpt2 = names2 - names1
    if only_in_ckpt1:
        logger.warning(f'Parameters not in ckpt2: {only_in_ckpt1}')
    if only_in_ckpt2:
        logger.warning(f'Parameters not in ckpt1: {only_in_ckpt2}')

    # Iterate in a stable order so the saved JSON is reproducible across runs
    # (plain set iteration order depends on string hash randomization).
    common = sorted(names1 & names2, key=str)

    results = {}
    for key in tqdm(common):
        tensor1 = weights1[key]
        tensor2 = weights2[key]

        metrics = {}
        for metric, func in METRIC_FUNC.items():
            try:
                metrics[metric] = func(tensor1, tensor2)
            except Exception as e:
                # Best effort: one failing metric must not abort the whole comparison.
                metrics[metric] = 'error'
                logger.warning(f'Error when calculate {metric} of {key} for reason: {e}')
        results[key] = metrics

    # Write results to JSON file
    save_json(output_path, results, indent=4)
    logger.info(f"Comparison results written to {output_path}")
    return results