mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,138 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import json
18
+ import pandas as pd
19
+ from msprobe.core.common.file_utils import create_file_in_zip, load_json
20
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
21
+ from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
22
+ from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
23
+ from msprobe.core.common.decorator import recursion_depth_decorator
24
+ from msprobe.core.common.framework_adapter import FmkAdp
25
+
26
+
27
@recursion_depth_decorator("config_check: process_obj")
def process_obj(obj):
    """Recursively reduce a (possibly nested) object to its tensor features.

    Tensors are replaced by ``get_tensor_features(obj)``. Dicts keep their
    keys, tuples and lists become dicts keyed by element position, and any
    other leaf is recorded as an empty string.
    """
    if FmkAdp.is_tensor(obj):
        return get_tensor_features(obj)
    if isinstance(obj, dict):
        return {name: process_obj(item) for name, item in obj.items()}
    if isinstance(obj, (tuple, list)):
        return {position: process_obj(item) for position, item in enumerate(obj)}
    # Non-tensor leaves carry no comparable statistics.
    return ""
37
+
38
+
39
def parse_args_and_kargs(args, kwargs):
    """Return the processed positional and keyword inputs as one dict.

    The result has two entries, ``'args'`` and ``'kwargs'``, each holding the
    output of ``process_obj`` for the corresponding input.
    """
    return {
        'args': process_obj(args),
        'kwargs': process_obj(kwargs),
    }
47
+
48
+
49
# Keys that identify a leaf statistics dict in the compared structures.
_DATASET_STAT_KEYS = ('max', 'min', 'mean', 'norm')


def _relative_diff(bench_value, cmp_value):
    """Return ``|bench - cmp| / |bench|``, or None when it cannot be computed."""
    numeric = (int, float)
    if not isinstance(bench_value, numeric) or not isinstance(cmp_value, numeric):
        return None
    if bench_value == 0:
        return None
    # Divide by the magnitude so a negative benchmark statistic does not
    # produce a negative "relative difference".
    return abs(bench_value - cmp_value) / abs(bench_value)


@recursion_depth_decorator("config_check: compare_dataset_dicts")
def compare_dataset_dicts(dict1, dict2, tag=''):
    """Recursively compare two nested feature dicts and list the differences.

    Args:
        dict1: Benchmark-side dict.
        dict2: Comparison-side dict.
        tag: Dotted path of the current nesting level, used to label rows.

    Returns:
        A list of dicts, one per reported entry, each with ``tag``, ``equal``
        and ``status``:

        * key only in ``dict1``: ``status='delete'``, ``equal=False``;
        * key only in ``dict2``: ``status='added'``, ``equal=False``;
        * key in both whose ``dict1`` value is a statistics dict (exactly the
          keys max/min/mean/norm): ``status='unchanged'``, ``equal`` is the
          dict equality, plus one ``<stat>_relative_diff`` per statistic.
          A relative diff is None when the benchmark statistic is 0, when
          either side is not a number, or when the comparison side has no
          such statistic.

        Values in ``dict1`` that are not dicts are skipped; other dicts are
        descended into.
    """
    results = []
    # Handle the keys present in dict1.
    for key in dict1:
        new_tag = f"{tag}.{key}" if tag else key
        if key not in dict2:
            results.append({'tag': new_tag, 'equal': False, 'status': 'delete'})
            continue
        value1 = dict1[key]
        value2 = dict2[key]
        if not isinstance(value1, dict):
            continue
        if set(value1.keys()) == set(_DATASET_STAT_KEYS):
            # The comparison side may not be a statistics dict (e.g. "" for a
            # non-tensor input); report a mismatch instead of raising.
            cmp_stats = value2 if isinstance(value2, dict) else {}
            result = {'tag': new_tag, 'equal': value1 == value2, 'status': 'unchanged'}
            for stat in _DATASET_STAT_KEYS:
                result[f"{stat}_relative_diff"] = _relative_diff(value1[stat], cmp_stats.get(stat))
            results.append(result)
        else:
            results.extend(compare_dataset_dicts(value1, value2, new_tag))
    # Handle the keys that exist only in dict2.
    for key in dict2:
        if key not in dict1:
            new_tag = f"{tag}.{key}" if tag else key
            results.append({'tag': new_tag, 'equal': False, 'status': 'added'})
    return results
81
+
82
+
83
def compare_dataset(bench_dir, cmp_dir):
    """Compare the per-step, per-rank ``dataset.json`` files of two runs.

    Every ``<step>/<rank>/dataset.json`` found under ``bench_dir`` is compared
    with the file at the same relative location under ``cmp_dir``; a pair is
    skipped when either file is missing. Step and rank numbers are taken from
    the directory names with the ``step`` / ``rank`` text removed.

    Returns:
        A DataFrame with the ``DatasetChecker.result_header`` columns, sorted
        by step and then rank.
    """
    rows = []
    for step_name in os.listdir(bench_dir):
        bench_step_dir = os.path.join(bench_dir, step_name)
        if not os.path.isdir(bench_step_dir):
            continue
        cmp_step_dir = os.path.join(cmp_dir, step_name)
        for rank_name in os.listdir(bench_step_dir):
            bench_json = os.path.join(bench_step_dir, rank_name, 'dataset.json')
            cmp_json = os.path.join(cmp_step_dir, rank_name, 'dataset.json')
            if not (os.path.isfile(bench_json) and os.path.isfile(cmp_json)):
                continue

            rank_rows = compare_dataset_dicts(load_json(bench_json), load_json(cmp_json))
            for entry in rank_rows:
                entry['step'] = int(step_name.replace("step", ""))
                entry['rank'] = int(rank_name.replace("rank", ""))
            rows.extend(rank_rows)

    frame = pd.DataFrame(rows, columns=DatasetChecker.result_header)
    return frame.sort_values(by=['step', 'rank'], ascending=[True, True])
107
+
108
+
109
@register_checker_item("dataset")
class DatasetChecker(BaseChecker):
    """Checker that records and compares features of the forward inputs."""

    input_needed = "model"
    multi_rank = True

    # Sub-directory used for this checker's files, and the report columns.
    target_name_in_zip = "dataset"
    result_header = ['step', 'rank', 'tag', 'equal', 'max_relative_diff',
                     'min_relative_diff', 'mean_relative_diff', 'norm_relative_diff']

    @staticmethod
    def pack(pack_input):
        """Register a pre-forward callback that writes input features to the zip."""
        zip_path = pack_input.output_zip_path

        def collect_input(model, args, kwargs, step):
            # Features are stored at dataset/step<N>/rank<R>/dataset.json.
            features = parse_args_and_kargs(args, kwargs)
            path_in_zip = os.path.join(
                DatasetChecker.target_name_in_zip,
                f"step{step}",
                f"rank{FmkAdp.get_rank_id()}",
                "dataset.json",
            )
            create_file_in_zip(zip_path, path_in_zip, json.dumps(features, indent=4))
            config_checking_print("add first dataset input features to zip")

        register_pre_forward_fun_list(collect_input)

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare the dataset dumps of two runs.

        Returns:
            ``(target_name_in_zip, pass_check, df)``; ``pass_check`` is True
            when no row of ``df`` has ``equal`` set to False.
        """
        name = DatasetChecker.target_name_in_zip
        df = compare_dataset(os.path.join(bench_dir, name), os.path.join(cmp_dir, name))
        pass_check = False not in df['equal'].values
        return name, pass_check, df
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import json
18
+
19
+ import pandas as pd
20
+
21
+ from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip
22
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
23
+ from msprobe.core.config_check.config_checker import register_checker_item
24
+ from msprobe.core.config_check.utils.utils import config_checking_print
25
+ from msprobe.core.common.const import Const
26
+
27
+
28
# Directory containing this module; the env.yaml resource is located relative to it.
dirpath = os.path.dirname(__file__)
# Path of the YAML file that compare_env_data loads as the set of variables to compare.
env_yaml_path = os.path.join(dirpath, "../resource/env.yaml")
30
+
31
+
32
def collect_env_data():
    """Return a plain-dict snapshot of the current process environment."""
    return dict(os.environ.items())
37
+
38
+
39
def get_device_type(env_json):
    """Classify an environment dump by its variable names.

    Returns ``Const.NPU_LOWERCASE`` when any variable name contains
    ``Const.ASCEND``, otherwise ``Const.GPU_LOWERCASE``.
    """
    has_ascend_var = any(Const.ASCEND in name for name in env_json.keys())
    return Const.NPU_LOWERCASE if has_ascend_var else Const.GPU_LOWERCASE
44
+
45
+
46
def compare_env_data(npu_path, bench_path):
    """Compare the environment variables listed in env.yaml between two dumps.

    Args:
        npu_path: Path of the JSON environment dump that is reported in the
            ``cmp_*`` columns.
        bench_path: Path of the JSON environment dump that is reported in the
            ``bench_*`` columns.

    Returns:
        A DataFrame with the ``EnvArgsChecker.result_header`` columns. A row
        is added with level ``"error"`` when a variable known for both device
        types has different values, and with level ``"warning"`` when the
        variable is only defined for one side's device type. A variable that
        is missing from a dump takes the ``default_value`` from env.yaml.
    """
    necessary_env = load_yaml(env_yaml_path)
    cmp_data = load_json(npu_path)
    cmp_type = get_device_type(cmp_data)
    bench_data = load_json(bench_path)
    bench_type = get_device_type(bench_data)

    data = []
    for value in necessary_env.values():
        cmp_env = value.get(cmp_type)
        bench_env = value.get(bench_type)
        if not bench_env and not cmp_env:
            continue

        if not cmp_env:
            bench_env_name = bench_env["name"]
            # Same lookup rule as the other branches: only a missing variable
            # takes the default; one explicitly set to "" is reported as "".
            bench_value = bench_data.get(bench_env_name, bench_env["default_value"])
            data.append([bench_env_name, "only bench has this env", bench_value, "", "warning"])
            continue

        cmp_env_name = cmp_env["name"]
        cmp_value = cmp_data.get(cmp_env_name, cmp_env["default_value"])
        if not bench_env:
            data.append(["only cmp has this env", cmp_env_name, "", cmp_value, "warning"])
            continue

        bench_env_name = bench_env["name"]
        bench_value = bench_data.get(bench_env_name, bench_env["default_value"])
        if cmp_value != bench_value:
            data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, "error"])

    return pd.DataFrame(data, columns=EnvArgsChecker.result_header)
75
+
76
+
77
@register_checker_item("env")
class EnvArgsChecker(BaseChecker):
    """Checker that records the process environment and compares two dumps."""

    # Name of this checker's file inside the zip, and the report columns.
    target_name_in_zip = "env"
    result_header = ["bench_env_name", "cmp_env_name", "bench_value", "cmp_value", "level"]

    @staticmethod
    def pack(pack_input):
        """Write the current environment variables, as JSON, into the zip."""
        output_zip_path = pack_input.output_zip_path
        env_args_dict = collect_env_data()
        create_file_in_zip(output_zip_path, EnvArgsChecker.target_name_in_zip, json.dumps(env_args_dict, indent=4))
        config_checking_print("add env args to zip")

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare the environment dumps found in ``bench_dir`` and ``cmp_dir``.

        Returns:
            ``(target_name_in_zip, pass_check, df)``; ``pass_check`` is True
            when no row of ``df`` has level ``"error"``.
        """
        bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip)
        cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip)
        # compare_env_data takes the comparison-side dump first and the
        # benchmark dump second. Passing them the other way round swaps the
        # bench_*/cmp_* columns and the "only bench"/"only cmp" labels.
        df = compare_env_data(cmp_env_data, bench_env_data)
        pass_check = "error" not in df['level'].values
        return EnvArgsChecker.target_name_in_zip, pass_check, df
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import json
18
+ from difflib import SequenceMatcher
19
+
20
+ from typing import Union, List, Dict, Any
21
+ import pandas as pd
22
+
23
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
24
+ from msprobe.core.config_check.config_checker import register_checker_item
25
+ from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict
26
+ from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory
27
+ from msprobe.core.common.file_utils import (os_walk_for_files, create_file_in_zip, load_json, create_file_with_list,
28
+ FileOpen, load_yaml)
29
+ from msprobe.core.common.const import Const
30
+
31
+
32
# Directory containing this module; the hyperparameter.yaml resource is located relative to it.
dirpath = os.path.dirname(__file__)
hyperparameters_path = os.path.join(dirpath, "../resource/hyperparameter.yaml")
# Loaded once at import time; _fuzzy_match_parameter iterates it as
# standard name -> aliases.
parameter_name_mapping = load_yaml(os.path.realpath(hyperparameters_path))
# Rebound by the wrapper installed in HyperparameterChecker.apply_patches to the
# first truthy result of megatron's training.get_args; stays empty otherwise.
hyperparameters_dict = {}
36
+
37
+
38
@register_checker_item("hyperparameter")
class HyperparameterChecker(BaseChecker):
    """Checker that records hyperparameters and compares them between two runs.

    Static hyperparameters are parsed from launch scripts. Dynamic ones are
    the object captured from megatron's ``training.get_args`` once
    ``apply_patches`` has been called.
    """

    # Sub-directory for this checker's files, report columns, and the two
    # file names: [0] holds static and [1] dynamic hyperparameters.
    target_name_in_zip = "hyperparameters"
    result_header = ["file_name", "bench_para", "cmp_para", "bench_value", "cmp_value", "matched_with", "level"]
    hyperparameters_file_list = ["hyperparameters_static.json", "hyperparameters_dynamic.json"]

    @staticmethod
    def pack(pack_input):
        """Write the static and dynamic hyperparameters into the zip.

        Raises:
            TypeError: If ``pack_input.shell_path`` is set but is not a list.
        """
        shell_path = pack_input.shell_path
        output_zip_path = pack_input.output_zip_path

        if shell_path:
            if not isinstance(shell_path, list):
                raise TypeError("shell_path should be a list of file paths.")

            hyperparameters = {}
            parser_factory = ParserFactory()
            for script_path in shell_path:
                if os.path.isfile(script_path):
                    # The parser is selected by the script's file extension.
                    parser = parser_factory.get_parser(os.path.splitext(script_path)[1])
                    update_dict(hyperparameters, parser.run(os.path.realpath(script_path)))
                else:
                    config_checking_print(f"Warning: Script path {script_path} is not a file.")
            if hyperparameters:
                create_file_in_zip(output_zip_path,
                                   os.path.join(HyperparameterChecker.target_name_in_zip,
                                                HyperparameterChecker.hyperparameters_file_list[0]),
                                   json.dumps(hyperparameters, indent=4))
                config_checking_print("add static hyperparameters args to zip")
            else:
                config_checking_print(f"Warning: Failed to extract hyperparameters from script {shell_path}")

        # Dynamic hyperparameters do not depend on shell_path.
        if hyperparameters_dict:
            # Values that are not JSON-serializable are written as null.
            create_file_in_zip(output_zip_path,
                               os.path.join(HyperparameterChecker.target_name_in_zip,
                                            HyperparameterChecker.hyperparameters_file_list[1]),
                               json.dumps(vars(hyperparameters_dict), default=lambda x: None, indent=4))
            config_checking_print("add dynamic hyperparameters args to zip")

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare each hyperparameter file that exists in both directories.

        Returns:
            ``(target_name_in_zip, pass_check, df)``; ``pass_check`` is True
            when no row of ``df`` has level ``"error"``.
        """
        all_diffs = []
        for file_name in HyperparameterChecker.hyperparameters_file_list:
            bench_model_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip, file_name)
            cmp_model_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip, file_name)
            if os.path.isfile(bench_model_dir) and os.path.isfile(cmp_model_dir):
                bench_hyperparameters = load_json(bench_model_dir)
                cmp_hyperparameters = load_json(cmp_model_dir)
                all_diffs.extend(
                    HyperparameterChecker.compare_param(bench_hyperparameters, cmp_hyperparameters, file_name))
        df = pd.DataFrame(all_diffs, columns=HyperparameterChecker.result_header)
        pass_check = "error" not in df["level"].values
        return HyperparameterChecker.target_name_in_zip, pass_check, df

    @staticmethod
    def compare_param(bench_params, cmp_params, file_name):
        """Return the rows describing how two parameter dicts differ.

        Each benchmark parameter is matched to a comparison parameter through
        ``_fuzzy_match_parameter``. A matched pair with different values gives
        an ``"error"`` row; a parameter present on one side only gives a
        ``"warning"`` row. Rows follow ``result_header`` and are sorted by
        file name, benchmark name and comparison name.

        ``cmp_params`` is not modified.
        """
        all_diffs = []
        # Work on a copy so that consuming matched entries does not empty the
        # caller's dict.
        unmatched_cmp_params = dict(cmp_params)
        for bench_param_name, bench_param_value in bench_params.items():
            matched_cmp_param_name, matched_with = HyperparameterChecker._fuzzy_match_parameter(
                bench_param_name, unmatched_cmp_params)
            if matched_cmp_param_name:
                cmp_param_value = unmatched_cmp_params[matched_cmp_param_name]
                if bench_param_value != cmp_param_value:
                    all_diffs.append(
                        [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value,
                         matched_with, "error"])
                # Each comparison parameter can be matched at most once.
                del unmatched_cmp_params[matched_cmp_param_name]
            else:
                all_diffs.append(
                    [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "", "warning"])
        for cmp_param_name, cmp_param_value in unmatched_cmp_params.items():
            all_diffs.append([file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", "warning"])
        # Sort on the name columns only; the value columns can hold values of
        # different types, which cannot be ordered against each other.
        all_diffs.sort(key=lambda row: row[:3])
        return all_diffs

    @staticmethod
    def apply_patches(fmk):
        """Wrap megatron's ``training.get_args`` to capture its first truthy result.

        Only a message is printed when megatron cannot be imported or the
        patch fails.
        """
        import functools

        try:
            from megatron import training

            def collect_hyperparameter_wrapper(func):
                @functools.wraps(func)
                def wrapper(*args, **kwargs):
                    global hyperparameters_dict
                    result = func(*args, **kwargs)
                    if not hyperparameters_dict:
                        hyperparameters_dict = result
                    return result
                return wrapper
            training.get_args = collect_hyperparameter_wrapper(training.get_args)
        except ImportError:
            config_checking_print("No megatron find.")
        except Exception as e:
            config_checking_print(f"Patch megatron method failed, detail:{str(e)}")

    @staticmethod
    def _fuzzy_match_parameter(param_name: str, available_params: Dict[str, Any]):
        """
        Fuzzy matches a parameter name against available parameter names using predefined
        mappings and string similarity.

        Matching is tried in this order: exact name, then the canonical name
        or one of its aliases from ``parameter_name_mapping``, then the
        available name with the highest case-insensitive ``SequenceMatcher``
        ratio above 0.8.

        Returns:
            ``(matched_name, match_mode)``, or ``(None, None)`` when nothing
            matches.
        """
        if param_name in available_params:
            return param_name, Const.MATCH_MODE_NAME

        canonical_name = None
        for standard_name, aliases in parameter_name_mapping.items():
            if param_name == standard_name or param_name in aliases:
                canonical_name = standard_name
                break

        if canonical_name:
            if canonical_name in available_params:
                return canonical_name, Const.MATCH_MODE_MAPPING
            for alias in parameter_name_mapping[canonical_name]:
                if alias in available_params:
                    config_checking_print(
                        f"Matched '{param_name}' to alias '{alias}' via canonical name '{canonical_name}'")
                    return alias, Const.MATCH_MODE_MAPPING

        best_match_name = None
        # Doubles as the acceptance threshold: a candidate must beat 0.8.
        best_match_ratio = 0.8
        for available_param_name in available_params:
            ratio = SequenceMatcher(None, param_name.lower(), available_param_name.lower()).ratio()
            if ratio > best_match_ratio:
                best_match_ratio = ratio
                best_match_name = available_param_name

        if best_match_name:
            config_checking_print(
                f"Fuzzy matched parameter '{param_name}' to '{best_match_name}' (similarity: {best_match_ratio:.2f})")
            return best_match_name, f"{Const.MATCH_MODE_SIMILARITY}:{best_match_ratio}"

        return None, None
@@ -0,0 +1,90 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import pandas as pd
18
+ try:
19
+ import importlib.metadata as metadata
20
+ except ImportError:
21
+ import importlib_metadata as metadata
22
+
23
+ from msprobe.core.common.file_utils import load_yaml, create_file_in_zip
24
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
25
+ from msprobe.core.config_check.config_checker import register_checker_item
26
+ from msprobe.core.config_check.utils.utils import config_checking_print
27
+ from msprobe.core.common.file_utils import FileOpen, save_excel
28
+
29
# Directory containing this module; the dependency.yaml resource is located relative to it.
dirpath = os.path.dirname(__file__)
# Path of the YAML file whose "dependency" list compare_pip_data reads.
depend_path = os.path.join(dirpath, "../resource/dependency.yaml")
31
+
32
+
33
def load_pip_txt(file_path):
    """Parse a file of ``name=version`` lines into a ``{name: version}`` dict.

    A line that does not split into exactly two parts on ``=`` is stored under
    its first part with an empty version.
    """
    versions = {}
    with FileOpen(file_path, 'r', encoding='utf-8') as handle:
        raw_lines = handle.readlines()
    for raw_line in raw_lines:
        fields = raw_line.strip().split("=")
        versions[fields[0]] = fields[1] if len(fields) == 2 else ""
    return versions
41
+
42
+
43
def collect_pip_data():
    """Return one ``name=version`` line per installed distribution.

    Distributions without metadata are left out; every line ends with a
    newline.
    """
    return "".join(
        f"{dist.metadata.get('Name')}={dist.version}\n"
        for dist in metadata.distributions()
        if dist.metadata
    )
50
+
51
+
52
def compare_pip_data(bench_pip_path, cmp_pip_path, fmk):
    """Report required packages whose versions differ between two pip listings.

    The packages checked are the ``dependency`` list from dependency.yaml plus
    ``fmk``.

    Returns:
        A DataFrame with the ``PipPackageChecker.result_header`` columns. It
        holds one ``"error"`` row per package whose versions differ; a missing
        or empty version is shown as ``'None'``.
    """
    required_packages = load_yaml(depend_path)["dependency"]
    required_packages.append(fmk)
    bench_versions = load_pip_txt(bench_pip_path)
    cmp_versions = load_pip_txt(cmp_pip_path)

    rows = []
    for package in required_packages:
        bench_version = bench_versions.get(package)
        cmp_version = cmp_versions.get(package)
        if bench_version == cmp_version:
            continue
        rows.append([package, bench_version or 'None', cmp_version or 'None', "error"])

    return pd.DataFrame(rows, columns=PipPackageChecker.result_header)
69
+
70
+
71
@register_checker_item("pip")
class PipPackageChecker(BaseChecker):
    """Checker that records installed package versions and compares two listings."""

    # Name of this checker's file inside the zip, and the report columns.
    target_name_in_zip = "pip"
    result_header = ['package', 'bench version', 'cmp version', 'level']

    @staticmethod
    def pack(pack_input):
        """Write the installed-package listing into the zip."""
        zip_path = pack_input.output_zip_path
        listing = collect_pip_data()
        create_file_in_zip(zip_path, PipPackageChecker.target_name_in_zip, listing)
        config_checking_print("add pip info to zip")

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare the pip listings found in ``bench_dir`` and ``cmp_dir``.

        Returns:
            ``(target_name_in_zip, pass_check, df)``; ``pass_check`` is True
            when no row of ``df`` has level ``"error"``.
        """
        name = PipPackageChecker.target_name_in_zip
        df = compare_pip_data(os.path.join(bench_dir, name), os.path.join(cmp_dir, name), fmk)
        pass_check = "error" not in df['level'].values
        return name, pass_check, df