mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/config_check/checkers/random_checker.py
@@ -0,0 +1,367 @@
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
from functools import wraps
from typing import Callable, List, Dict, Tuple, Optional
import inspect
import os
import json
from collections import defaultdict
import difflib

import numpy as np
import pandas as pd
from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
from msprobe.core.common.file_utils import create_file_in_zip, load_json
from msprobe.core.config_check.checkers.base_checker import BaseChecker
from msprobe.core.config_check.utils.utils import config_checking_print
from msprobe.core.common.framework_adapter import FmkAdp
from msprobe.core.common.const import Const
from msprobe.core.common.log import logger


# Data structure: {random op name: [{count: number of calls, stack: call stack list}]}
random_op_stats = defaultdict(list)


def get_call_stack(frame) -> List[str]:
    """Collect detailed call-stack information; each entry contains the full path, line number, function name and code line."""
    stack = []
    current_frame = frame.f_back  # skip the current function

    while current_frame:
        frame_info = inspect.getframeinfo(current_frame)
        filename = os.path.abspath(frame_info.filename)
        code_line = frame_info.code_context[0].strip() if frame_info.code_context else ""

        # Format the entry as detailed frame information
        stack_entry = f"File {filename}, line {frame_info.lineno}, in {frame_info.function}, {code_line}"
        stack.append(stack_entry)

        current_frame = current_frame.f_back

    # Reverse the stack so it reads in call order (bottom to top)
    return stack[::-1]


def track_random_call(func: Callable, name: str):
    """Record call information for a random function."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        frame = inspect.currentframe()
        stack = get_call_stack(frame)

        # Update call statistics: op name -> [{count: calls, stack: call stack list}]
        # Check whether a record with the same call stack already exists
        for entry in random_op_stats[name]:
            if entry['stack'] == stack:
                entry['count'] += 1
                break
        else:
            # Add a new call-stack record
            random_op_stats[name].append({'count': 1, 'stack': stack})

        try:
            result = func(*args, **kwargs)
            return result
        except Exception as e:
            raise e
        finally:
            del frame

    return wrapper


def load_stats_files(directory: str) -> Dict[str, Dict[str, List[Dict]]]:
    """Load all statistics files in a directory and organize the data by rank."""
    rank_data = {}
    for file in os.listdir(directory):
        file_path = os.path.join(directory, file)
        if file.startswith('rank') and file.endswith('.json'):
            rank = os.path.basename(file.split('.')[0])[4:]
            if not rank or not rank.isdigit():
                logger.error(f"extract rank id from {file} failed")
                raise ValueError

            # Load and store the data
            data = load_json(file_path)
            rank_data[int(rank)] = data

    return rank_data


def stack_match(stack1: List[str], stack2: List[str], threshold: float = 0.8) -> bool:
    """
    Compare whether two call stacks are similar, weighting path, function name and code line equally (1/3 each); every frame must reach the 0.8 similarity threshold.

    Args:
    - stack1: the first call stack list
    - stack2: the second call stack list
    - threshold: similarity threshold, defaults to 0.8

    Returns:
    - bool indicating whether the two call stacks are similar
    """
    if len(stack1) != len(stack2):
        return False

    for frame1, frame2 in zip(stack1, stack2):
        # Extract the path, function name and code line
        path1, func1, code1 = _parse_frame(frame1)
        path2, func2, code2 = _parse_frame(frame2)

        # Compute the similarity score (path, function name and code line each weighted 1/3)
        path_score = _compare_path(path1, path2)
        func_score = 1.0 if func1 == func2 else 0.0
        # Code similarity
        code_score = difflib.SequenceMatcher(None, code1, code2).ratio()

        frame_score = (path_score + func_score + code_score) / 3.0
        if frame_score < threshold:
            return False

    return True


def _parse_frame(frame: str) -> Tuple[str, str, str]:
    """
    Parse a stack frame string and extract the path, function name and code line.

    Args:
    - frame: stack frame string in the format "File {path}, line {line}, in {func}, {code}"

    Returns:
    - path, func, code
    """
    path = func = code = ''
    stack_info = frame.split(' ')
    if len(stack_info) > 6:
        path = stack_info[1][:-1]
        func = stack_info[5][:-1]
        code = ' '.join(stack_info[6:])
    return path, func, code


def _compare_path(path1: str, path2: str) -> float:
    """Compare the similarity of two paths, considering only the file name."""
    if not path1 or not path2:
        return 0.0

    # Extract the file name (ignore the directory path)
    file1 = os.path.basename(path1)
    file2 = os.path.basename(path2)

    return 1.0 if file1 == file2 else 0.0


def find_matching_stack(bench_stack: List[str], cmp_stacks: List[Dict]) -> Optional[Dict]:
    """
    Find a matching call stack.

    Args:
    - bench_stack: the bench-side call stack list
    - cmp_stacks: the cmp-side call stack entries, each of the form {'count': calls, 'stack': call stack list}

    Returns:
    - the matching call stack entry, or None
    """
    for cmp_entry in cmp_stacks:
        if stack_match(cmp_entry['stack'], bench_stack):
            return cmp_entry

    return None


def stack_list_to_string(stack_list):
    """
    Convert a call stack list into a newline-separated string.
    If the input is a special marker (such as "no match stack"), return it unchanged.
    """
    if isinstance(stack_list, list):
        return '\n'.join(stack_list)
    return stack_list


def compare_random_calls(bench_dir: str = 'bench', cmp_dir: str = 'cmp') -> pd.DataFrame:
    """Compare the random call-stack statistics under two directories and produce a detailed comparison result."""
    bench_rank_data = load_stats_files(bench_dir)
    cmp_rank_data = load_stats_files(cmp_dir)

    # Collect all ranks
    all_ranks = sorted(set(bench_rank_data.keys()) | set(cmp_rank_data.keys()))

    results = []

    for rank in all_ranks:
        bench_data = bench_rank_data.get(rank, {})
        cmp_data = cmp_rank_data.get(rank, {})

        # Collect all ops
        all_ops = set(bench_data.keys()) | set(cmp_data.keys())

        for op in all_ops:
            bench_stacks = bench_data.get(op, [])
            cmp_stacks = cmp_data.get(op, [])

            # Handle every call stack on the bench side
            for bench_entry in bench_stacks:
                bench_stack = bench_entry['stack']
                bench_count = bench_entry['count']

                # Look for a matching call stack on the cmp side
                cmp_entry = find_matching_stack(bench_stack, cmp_stacks)

                if cmp_entry:
                    cmp_count = cmp_entry['count']
                    check_result = bench_count == cmp_count
                    results.append([op, rank, bench_stack, cmp_entry['stack'], bench_count, cmp_count, check_result])
                else:
                    # No matching call stack
                    results.append([op, rank, bench_stack, "no match stack", bench_count, 0, False])

            # Handle cmp-side call stacks that never appeared on the bench side
            for cmp_entry in cmp_stacks:
                cmp_stack = cmp_entry['stack']
                # Check whether it was already handled above
                if not any(stack_match(bench_entry['stack'], cmp_stack) for bench_entry in bench_stacks):
                    results.append([op, rank, "no match stack", cmp_stack, 0, cmp_entry['count'], False])

    # Build the DataFrame
    df = pd.DataFrame(results, columns=RandomChecker.result_header)

    # Apply the conversion function
    df['bench_stack'] = df['bench_stack'].apply(stack_list_to_string)
    df['cmp_stack'] = df['cmp_stack'].apply(stack_list_to_string)

    return df


def torch_patchs():
    """Patch torch random functions."""
    import torch
    torch_patches = {
        'rand': torch.rand,
        'randint': torch.randint,
        'randn': torch.randn,
        'rand_like': torch.rand_like,
        'randint_like': torch.randint_like,
        'randn_like': torch.randn_like,
        'manual_seed': torch.manual_seed
    }
    for name, func in torch_patches.items():
        setattr(torch, name, track_random_call(func, f"torch.{name}"))

    tensor_patches = {
        'exponential_': torch.Tensor.exponential_,
        'geometric_': torch.Tensor.geometric_,
        'log_normal_': torch.Tensor.log_normal_,
        'cauchy_': torch.Tensor.cauchy_
    }
    for name, func in tensor_patches.items():
        setattr(torch.Tensor, name, track_random_call(func, f"torch.Tensor.{name}"))


def mindspore_patchs():
    """Patch MindSpore random functions."""
    import mindspore

    mindspore_ops_patches = {
        'rand': mindspore.ops.uniform,
        'randint': mindspore.ops.randint,
        'randn': mindspore.ops.normal
    }
    for name, func in mindspore_ops_patches.items():
        setattr(mindspore.ops, name, track_random_call(func, f"mindspore.ops.{name}"))

    mindspore_patches = {
        'manual_seed': mindspore.set_seed
    }
    for name, func in mindspore_patches.items():
        setattr(mindspore, name, track_random_call(func, f"mindspore.{name}"))


@register_checker_item("random")
class RandomChecker(BaseChecker):
    input_needed = None
    target_name_in_zip = "random"
    result_header = ['op', 'rank', 'bench_stack', 'cmp_stack', 'bench_count', 'cmp_count', 'check_result']
    write_once = False

    @staticmethod
    def pack(pack_input):
        """Pack the random call statistics into the zip file."""
        output_zip_path = pack_input.output_zip_path

        def collect_input(model, args, kwargs, step):
            if RandomChecker.write_once:
                return

            random_stats_dir = os.path.join(RandomChecker.target_name_in_zip)
            stats_filepath = os.path.join(random_stats_dir, f"rank{FmkAdp.get_rank_id()}.json")

            # Convert to JSON format: {op name: [{count: calls, stack: call stack list}]}
            stats_json = {}
            for op_name, entries in random_op_stats.items():
                stats_json[op_name] = entries

            create_file_in_zip(output_zip_path, stats_filepath, json.dumps(stats_json, indent=4))
            config_checking_print(f"packed random call statistics into: {stats_filepath}")
            RandomChecker.write_once = True

        register_pre_forward_fun_list(collect_input)

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare two sets of random call statistics."""
        bench_stats_path = os.path.join(bench_dir, RandomChecker.target_name_in_zip)
        cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip)

        df = compare_random_calls(bench_stats_path, cmp_stats_path)
        pass_check = False not in df['check_result'].values

        return RandomChecker.target_name_in_zip, pass_check, df

    @staticmethod
    def apply_patches(fmk=Const.PT_FRAMEWORK):
        """Apply the random function patches."""
        # Patch the Python random module
        random_patches = {
            'random': random.random,
            'randint': random.randint,
            'uniform': random.uniform,
            'choice': random.choice
        }
        for name, func in random_patches.items():
            setattr(random, name, track_random_call(func, f"random.{name}"))

        # Patch NumPy random functions
        np_random_patches = {
            'rand': np.random.rand,
            'randint': np.random.randint,
            'choice': np.random.choice,
            'normal': np.random.normal
        }
        for name, func in np_random_patches.items():
            setattr(np.random, name, track_random_call(func, f"np.random.{name}"))

        # Patch framework-specific random functions
        if fmk == Const.PT_FRAMEWORK:
            torch_patchs()
        elif fmk == Const.MS_FRAMEWORK:
            mindspore_patchs()
        else:
            raise Exception(f"Unsupported framework: {fmk}, supported frameworks: {FmkAdp.supported_fmk}")
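
The functions above are wired into the config_check workflow through register_checker_item and register_pre_forward_fun_list, but they can also be exercised by hand. A minimal sketch (not part of the package), assuming a PyTorch environment and two hypothetical directories bench_random/ and cmp_random/ that already contain per-rank rank{N}.json statistics in the format written by RandomChecker.pack:

import json
import numpy as np
from msprobe.core.common.const import Const
from msprobe.core.config_check.checkers.random_checker import (
    RandomChecker, compare_random_calls, random_op_stats)

# Wrap random, np.random and the torch RNG entry points so every call is recorded
RandomChecker.apply_patches(fmk=Const.PT_FRAMEWORK)
np.random.rand(3)          # recorded as "np.random.rand" together with its call stack
np.random.rand(3)          # called from a different source line, so a separate stack entry

# random_op_stats maps each op name to a list of {count, stack} entries
print(json.dumps({op: [e["count"] for e in entries]
                  for op, entries in random_op_stats.items()}, indent=2))

# Comparing two statistics directories yields one DataFrame row per (op, rank, call stack)
df = compare_random_calls("bench_random", "cmp_random")   # hypothetical paths
print(df[["op", "rank", "bench_count", "cmp_count", "check_result"]])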
msprobe/core/config_check/checkers/weights_checker.py
@@ -0,0 +1,147 @@
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import pandas as pd

from msprobe.core.common.file_utils import create_file_in_zip, os_walk_for_files, load_json
from msprobe.core.config_check.checkers.base_checker import BaseChecker
from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
from msprobe.core.common.framework_adapter import FmkAdp


def collect_weights_data(model):
    weights_data = {}
    for name, param in FmkAdp.named_parameters(model):
        if param.dtype != FmkAdp.dtype("float32"):
            param = param.float()
        weights_data[name] = get_tensor_features(param)
    return weights_data


def compare_weight_file(bench_file, cmp_file):
    bench_data = load_json(bench_file)
    cmp_data = load_json(cmp_file)

    results = []
    for weight_name in set(bench_data.keys()) | set(cmp_data.keys()):
        result = {
            "weight_name": weight_name,
            "equal": None,
            "max_relative_diff": None,
            "min_relative_diff": None,
            "mean_relative_diff": None,
            "norm_relative_diff": None
        }

        if weight_name not in bench_data:
            result["equal"] = "only cmp have"
            results.append(result)
            continue

        if weight_name not in cmp_data:
            result["equal"] = "only bench have"
            results.append(result)
            continue

        bench_vals = bench_data[weight_name]
        cmp_vals = cmp_data[weight_name]
        keys = ["max", "min", "mean", "norm"]
        equal = all([bench_vals[k] == cmp_vals[k] for k in keys])
        result["equal"] = equal

        for key in keys:
            diff_key = f"{key}_relative_diff"
            result[diff_key] = (abs(bench_vals[key] - cmp_vals[key]) / bench_vals[key]) \
                if bench_vals[key] != 0 else None

        results.append(result)

    return results


def compare_weight(bench_dir, cmp_dir):
    all_results = []
    bench_files_info = os_walk_for_files(bench_dir, 10)
    for info in bench_files_info:
        if not info["file"].endswith('.json'):
            continue
        bench_file = os.path.join(info["root"], info["file"])
        relative_path = os.path.relpath(info["root"], bench_dir)
        cmp_root = os.path.join(cmp_dir, relative_path)
        cmp_file = os.path.join(cmp_root, info["file"])

        path_list = relative_path.split(os.sep)
        if len(path_list) < 2:
            raise Exception("Can not compare weights because the extracted file has been corrupted!")
        step = int(path_list[0].replace("step", ""))
        rank = int(path_list[1].replace("rank", ""))

        if not os.path.exists(cmp_file):
            bench_data = load_json(bench_file)
            for weight_name in bench_data.keys():
                result = {
                    "step": step,
                    "rank": rank,
                    "weight_name": weight_name,
                    "equal": "only bench have",
                    "max_relative_diff": None,
                    "min_relative_diff": None,
                    "mean_relative_diff": None,
                    "norm_relative_diff": None
                }
                all_results.append(result)
        else:
            results = compare_weight_file(bench_file, cmp_file)
            for res in results:
                res["step"] = step
                res["rank"] = rank
                all_results.append(res)

    df = pd.DataFrame(all_results, columns=WeightsChecker.result_header)
    df = df.sort_values(by=['step', 'rank'], ascending=[True, True])
    return df


@register_checker_item("weights")
class WeightsChecker(BaseChecker):
    input_needed = "model"
    multi_rank = True

    target_name_in_zip = "weights"
    result_header = ["step", "rank", "weight_name", "equal", "max_relative_diff",
                     "min_relative_diff", "mean_relative_diff", "norm_relative_diff"]

    @staticmethod
    def pack(pack_input):
        output_zip_path = pack_input.output_zip_path

        def collect_weights(model, args, kwargs, step):
            weights_data_dict = collect_weights_data(model)
            weights_data_filepath = os.path.join(WeightsChecker.target_name_in_zip,
                                                 f"step{step}", f"rank{FmkAdp.get_rank_id()}", "weight.json")
            create_file_in_zip(output_zip_path, weights_data_filepath, json.dumps(weights_data_dict, indent=4))
            config_checking_print(f"add weights info to zip")
        register_pre_forward_fun_list(collect_weights)

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip)
        cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip)
        df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path)
        pass_check = False not in df['equal'].values
        return WeightsChecker.target_name_in_zip, pass_check, df
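
A minimal sketch (outside the packaged checker flow) of how the weight helpers above could be driven directly: collect per-parameter statistics from a small model, write them into the step{N}/rank{M}/weight.json layout that compare_weight expects, and compare two such trees. The tiny torch model and the bench_weights/ and cmp_weights/ directories are illustrative assumptions, and the sketch assumes FmkAdp resolves to the PyTorch backend in this environment:

import os
import json
import torch.nn as nn
from msprobe.core.config_check.checkers.weights_checker import collect_weights_data, compare_weight

def dump_weights(model, root):
    # Write statistics in the step{N}/rank{M}/weight.json layout expected by compare_weight
    target_dir = os.path.join(root, "step0", "rank0")
    os.makedirs(target_dir, exist_ok=True)
    with open(os.path.join(target_dir, "weight.json"), "w") as f:
        json.dump(collect_weights_data(model), f, indent=4)

dump_weights(nn.Linear(4, 2), "bench_weights")   # hypothetical bench tree
dump_weights(nn.Linear(4, 2), "cmp_weights")     # independently initialized, so statistics differ

df = compare_weight("bench_weights", "cmp_weights")
print(df[["step", "rank", "weight_name", "equal", "mean_relative_diff"]])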
msprobe/core/config_check/ckpt_compare/ckpt_comparator.py
@@ -0,0 +1,74 @@
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from tqdm import tqdm

from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists
from msprobe.core.common.log import logger
from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights
from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC


def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
    """Compare weights between two checkpoints using cosine similarity and L2 distance.

    Args:
        ckpt_path1 (str): Path to first checkpoint directory
        ckpt_path2 (str): Path to second checkpoint directory
        output_path (str): Path to save comparison results JSON file

    Returns:
        Dict: Dictionary containing comparison metrics for each parameter. The dictionary has the following structure:
            {
                "param_name": {
                    "cosine_similarity": float,  # Cosine similarity between parameter tensors
                    "l2_distance": float,        # L2 distance between parameter tensors
                    "shape": List[int]           # Shape of the parameter tensors
                },
                ...
            }
    """

    # Load both checkpoints
    check_path_before_create(output_path)
    check_path_not_exists(output_path)
    weights1 = load_megatron_weights(ckpt_path1)
    weights2 = load_megatron_weights(ckpt_path2)

    # Initialize results dictionary
    results = {}

    # Compare weights with matching keys
    common = set(weights1) & set(weights2)
    logger.warning(f'Parameters not in ckpt2: {set(weights1) - set(weights2)}')
    logger.warning(f'Parameters not in ckpt1: {set(weights2) - set(weights1)}')
    for key in tqdm(common):
        tensor1 = weights1[key]
        tensor2 = weights2[key]

        results[key] = {}
        for metric, func in METRIC_FUNC.items():
            try:
                results[key][metric] = func(tensor1, tensor2)
            except Exception as e:
                results[key][metric] = 'error'
                logger.warning(f'Error when calculate {metric} for reason: {e}')

    # Write results to JSON file
    save_json(output_path, results, indent=4)
    logger.info(f"Comparison results written to {output_path}")
    return results
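
The checkpoint comparison is documented in the newly added msprobe/docs/32.ckpt_compare.md, but the function can also be called directly. A minimal sketch, assuming two Megatron checkpoint directories at the hypothetical paths below; note that output_path must not already exist, since check_path_not_exists rejects existing files:

from msprobe.core.config_check.ckpt_compare.ckpt_comparator import compare_checkpoints

# Hypothetical Megatron checkpoint directories; the result is keyed by parameter name,
# with one value per metric registered in METRIC_FUNC (e.g. cosine similarity, L2 distance).
results = compare_checkpoints(
    ckpt_path1="ckpts/baseline/iter_0001000",
    ckpt_path2="ckpts/candidate/iter_0001000",
    output_path="ckpt_compare_result.json",   # must not exist yet
)

for name, metrics in list(results.items())[:5]:
    print(name, metrics)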