mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,302 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ from collections import defaultdict
19
+ from typing import Dict
20
+ import numpy as np
21
+ from msprobe.core.common.log import logger
22
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import FileOpen, load_yaml
25
+ from msprobe.core.common.framework_adapter import FmkAdp
26
+
27
# both weights and bias are partitioned in column parallel
COLUMN_PARALLEL_PARAMS = ['linear_qkv', 'linear_fc1', 'word_embeddings.weight', 'output_layer.weight']
# only weights are partitioned in row parallel
ROW_PARALLEL_PARAMS = ['linear_fc2.weight', 'linear_proj.weight']
# Key under which Megatron stores the training args object in a checkpoint.
ARGS = 'args'
# Raw strings: '\d' / '\.' in a plain literal are invalid escape sequences
# (SyntaxWarning since Python 3.12); the regex semantics are unchanged.
LAYER_IDX_PATTERN = re.compile(r'layers\.(\d+)\.')
EXPERT_IDX_PATTERN = re.compile(r'experts\.(\d+)\.')
ITER_DIR_PATTERN = re.compile(r'iter_([\d]{7})')
35
+
36
+
37
@recursion_depth_decorator('')
def _get_parameter(weights, prefix=''):
    """Recursively yield ``(name, numpy_array)`` pairs from a nested weights dict.

    Args:
        weights (dict): Possibly nested mapping of parameter names to tensors
            or further sub-dicts.
        prefix (str): Name accumulated from enclosing dict keys, joined with
            ``Const.SEP``.

    Yields:
        tuple: Full parameter name and the tensor converted to numpy via FmkAdp.
    """
    for k, v in weights.items():
        # strip() removes the leading separator produced when prefix is ''.
        name = Const.SEP.join([prefix, k]).strip(Const.SEP)
        if isinstance(v, dict):
            yield from _get_parameter(v, prefix=name)
        elif FmkAdp.is_tensor(v):
            # Values that are neither dicts nor tensors are silently skipped.
            yield name, FmkAdp.asnumpy(v)
45
+
46
+
47
def _map_to_mcore_local_names(param_name: str) -> str:
    """Map parameter names to mcore + local transformer implementation names.

    Substring replacements are taken from ``name_mapping.yaml`` located next to
    this module and applied in dict order.
    NOTE(review): the yaml file is re-read on every call — consider caching if
    this shows up on a hot path.
    """
    mcore_local_map = load_yaml(os.path.join(os.path.dirname(__file__), 'name_mapping.yaml'))
    for other_name, mcore_local_name in mcore_local_map.items():
        param_name = param_name.replace(other_name, mcore_local_name)

    return param_name
54
+
55
+
56
def _parse_real_layer_idx(param_name, num_layers_per_stage, pp_size, pp_rank):
    """Map local (virtual) pipeline stage layer index to global layer index.

    For virtual pipeline parallel, each pipeline stage is further divided into
    virtual stages. The global layer index needs to account for both pipeline
    stage and virtual stage.

    Args:
        param_name (str): Parameter name containing layer index: layers.x.<submodule_name>/<vpp_stage>
        num_layers_per_stage (int): Number of layers per pipeline stage
        pp_size (int): Pipeline parallel size
        pp_rank (int): Pipeline parallel rank of the shard being parsed

    Returns:
        str: Parameter name with the local layer index replaced by the global
        one; names without a ``layers.N.`` component are returned with only
        the vpp-stage suffix stripped.
    """
    # Extract local layer index from parameter name
    layer_match = re.search(LAYER_IDX_PATTERN, param_name)
    # Split off the '<name><SCOPE_SEPARATOR><vpp_stage>' suffix added during loading.
    param_name, vpp_stage = param_name.split(Const.SCOPE_SEPARATOR)
    if not layer_match:
        return param_name

    local_layer_idx = int(layer_match.group(1))
    vpp_stage = int(vpp_stage)

    # Calculate global layer index based on pipeline stage and virtual stage
    real_layer_idx = local_layer_idx + (pp_size * vpp_stage + pp_rank) * num_layers_per_stage

    return param_name.replace(f'layers.{local_layer_idx}', f'layers.{real_layer_idx}')
83
+
84
+
85
def _parse_real_expert_idx(param_name, num_experts_per_rank, exp_rank):
    """Map local expert index to global expert index. TODO: shared expert

    For expert parallel, experts are distributed across ranks. The local
    expert index embedded in ``param_name`` is shifted by the rank offset to
    obtain its global index.

    Args:
        param_name (str): Parameter name containing a local expert index.
        num_experts_per_rank (int): Number of experts hosted on each rank.
        exp_rank (int): Expert parallel rank.

    Returns:
        str: Name with the local expert index replaced by the global one;
        returned unchanged when no expert index is present.
    """
    match = EXPERT_IDX_PATTERN.search(param_name)
    if match is None:
        # No 'experts.N.' component: nothing to remap.
        return param_name

    local_idx = int(match.group(1))
    global_idx = exp_rank * num_experts_per_rank + local_idx

    return param_name.replace(f'experts.{local_idx}', f'experts.{global_idx}')
109
+
110
+
111
def _consolidate_tp_weights(weights: Dict) -> Dict:
    """Consolidate weights from different tensor parallel ranks into combined tensors.

    Args:
        weights: Dict mapping parameter name -> list of per-TP-rank arrays.

    Returns:
        Dict: Consolidated weights without rank information.
    """
    consolidated = {}
    for key, tensors in weights.items():
        # Generator (not list) inside any(): short-circuits without building a list.
        if any(name in key for name in COLUMN_PARALLEL_PARAMS):
            # Column parallel - concatenate the partitions along dim 0
            combined = np.concatenate(tensors, axis=0)
        elif any(name in key for name in ROW_PARALLEL_PARAMS):
            # Row parallel - concatenate the partitions along dim 1
            combined = np.concatenate(tensors, axis=1)
        else:
            # Unpartitioned params should be identical on every rank; warn and
            # keep the first copy otherwise.
            if not all(np.allclose(tensors[0], t) for t in tensors[1:]):
                logger.warning(f"Inconsistent values for {key} across TP ranks")
            combined = tensors[0]

        consolidated[key] = combined
    return consolidated
136
+
137
+
138
def _parse_num_layers_per_stage(tp_partition):
    """Infer layers-per-stage as (largest local layer index seen) + 1."""
    layer_indices = []
    for key in tp_partition.keys():
        found = re.findall(LAYER_IDX_PATTERN, key)
        if found:
            layer_indices.append(int(found[0]))

    return max(layer_indices) + 1
144
+
145
+
146
def parse_parallel_size(checkpoint_dir: str):
    """Parse tensor, pipeline and expert parallel sizes from checkpoint filenames.

    Args:
        checkpoint_dir (str): Directory containing checkpoint files

    Returns:
        tuple: (tensor_model_parallel_size, pipeline_model_parallel_size,
        expert_model_parallel_size, num_experts), read from the ``args``
        object stored in the first rank's checkpoint.

    Raises:
        ValueError: If no ``mp_rank_*`` directories are found.
    """
    # Find all rank directories
    rank_dirs = [d for d in os.listdir(checkpoint_dir) if d.startswith('mp_rank_')]

    if not rank_dirs:
        raise ValueError(f"No checkpoint rank directories found in {checkpoint_dir}")

    # Only the first rank's checkpoint is read; the parallel-size args are
    # expected to be the same in every rank's file.
    ckpt = FmkAdp.load_checkpoint(
        os.path.join(checkpoint_dir, rank_dirs[0], 'model_optim_rng.pt'),
        to_cpu=True,
        weights_only=False)
    args = ckpt[ARGS]
    return (
        args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size,
        args.expert_model_parallel_size,
        args.num_experts
    )
172
+
173
+
174
def parse_iteration(checkpoint_path: str) -> Dict:
    """
    Parse the checkpoint iteration directory from a given checkpoint path.

    If the path is a top-level checkpoint directory, this function reads the
    'latest_checkpointed_iteration.txt' file to determine the latest iteration.
    If the path is already an iteration directory (e.g., 'iter_0000005'), it extracts
    the iteration number from the path.

    Args:
        checkpoint_path (str): Path to the checkpoint directory or iteration directory.

    Returns:
        str: The full path to the checkpoint directory for the determined iteration.

    Raises:
        ValueError: If the checkpoint directory for the determined iteration does not exist.
    """
    iteration = None
    tracker_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
    if os.path.exists(tracker_file):
        with FileOpen(tracker_file, 'r') as f:
            latest_iteration = f.read().strip()
            if latest_iteration != 'release':
                try:
                    iteration = int(latest_iteration)
                except Exception:
                    logger.warning(
                        f"The latest_checkpointed_iteration is supposed to be `release` or an int. \
But {latest_iteration} is found."
                    )
                # NOTE(review): if int() failed above, iteration is still None and
                # the :07d format below raises TypeError — confirm the tracker file
                # is guaranteed to contain an int here.
                checkpoint_path = os.path.join(checkpoint_path, f'iter_{iteration:07d}')
    else:
        # No tracker file: assume the path already points at an iter_XXXXXXX dir.
        match = re.findall(ITER_DIR_PATTERN, checkpoint_path)
        if match:
            iteration = int(match[0])

    # Checkpoint directory for this iteration
    logger.info(f"Loaded checkpoint from iteration {iteration}")

    if not os.path.exists(checkpoint_path):
        raise ValueError(f"Checkpoint directory not found: {checkpoint_path}")

    return checkpoint_path
218
+
219
+
220
def get_weights_from_state_dict(state_dict):
    """Flatten a Megatron state dict into {name/<vpp_stage>: numpy_array}.

    Handles both the single-model layout (key ``'model'``) and the virtual
    pipeline layout (keys ``'model0'``, ``'model1'``, ...).
    """
    weights = {}

    def _collect(model_weights, stage):
        # Flatten one model chunk, remap names, and tag them with the vpp stage.
        for name, value in _get_parameter(model_weights):
            mapped = _map_to_mcore_local_names(name)
            weights[f"{mapped}{Const.SCOPE_SEPARATOR}{stage}"] = value

    if 'model' in state_dict:
        _collect(state_dict['model'], 0)
    elif 'model0' in state_dict:
        # vpp enabled: one 'model<i>' entry per virtual stage.
        stage = 0
        while f'model{stage}' in state_dict:
            _collect(state_dict[f'model{stage}'], stage)
            stage += 1

    return weights
239
+
240
+
241
def load_megatron_weights(checkpoint_path: str) -> Dict:
    """Load Megatron parallel checkpoint weights into a single dictionary.

    Iterates over all expert/pipeline/tensor parallel ranks, consolidates the
    tensor-parallel shards, and remaps local layer/expert indices to global ones.

    Args:
        checkpoint_path (str): Base checkpoint directory path

    Returns:
        combined_weights: Dict with weights from all ranks, keys include rank info

    Raises:
        ModuleNotFoundError: If the ``megatron`` package is not importable.
        ValueError: If no state could be loaded for a pipeline stage.
    """
    try:
        import megatron
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("No module named 'megatron', which is required to load a megatron ckpt") from e

    # Find latest iteration if not specified
    checkpoint_path = parse_iteration(checkpoint_path)

    # Parse parallel sizes from checkpoint directory structure
    tp_size, pp_size, exp_size, num_experts = parse_parallel_size(checkpoint_path)
    combined_weights = {}

    # Load checkpoints from all ranks
    for exp_rank in range(exp_size):
        num_layers_per_pipeline_stage = 0
        for pp_rank in range(pp_size):
            tp_partition = defaultdict(list)
            for tp_rank in range(tp_size):
                # Construct checkpoint path based on parallel ranks:
                # mp_rank_<tp>[_<pp>][_<exp>] depending on which dims are > 1.
                if pp_size > 1:
                    rank_dir = f'mp_rank_{tp_rank:02d}_{pp_rank:03d}'
                else:
                    rank_dir = f'mp_rank_{tp_rank:02d}'

                if exp_size > 1:
                    rank_dir = f'{rank_dir}_{exp_rank:03d}'

                ckpt_file = os.path.join(checkpoint_path, rank_dir, 'model_optim_rng.pt')
                try:
                    state_dict = FmkAdp.load_checkpoint(ckpt_file, to_cpu=True, weights_only=False)
                    partition = get_weights_from_state_dict(state_dict)
                    for key, weight in partition.items():
                        tp_partition[key].append(weight)

                except Exception as load_error:
                    # Best-effort: a missing/corrupt rank file is logged, not fatal.
                    logger.warning(f"Error loading {ckpt_file}: {load_error}")

            if not tp_partition:
                raise ValueError('No state loaded.')

            # Layers-per-stage is derived once from the first stage's keys.
            if not num_layers_per_pipeline_stage:
                num_layers_per_pipeline_stage = _parse_num_layers_per_stage(tp_partition)

            consolidated_weight = _consolidate_tp_weights(tp_partition)
            for key, value in consolidated_weight.items():
                key = _parse_real_layer_idx(key, num_layers_per_pipeline_stage, pp_size, pp_rank)
                if num_experts:
                    key = _parse_real_expert_idx(key, num_experts // exp_size, exp_rank)
                combined_weights[key] = value

    logger.info(f"Found {len(combined_weights)} total parameters across all ranks")

    return combined_weights
@@ -0,0 +1,83 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ from msprobe.core.common.log import logger
19
+ from msprobe.core.compare.npy_compare import CompareOps
20
+
21
+
22
+
23
def in_different_shape(a, b):
    """Return True (after logging a warning) when the arrays' shapes differ."""
    if a.shape == b.shape:
        return False
    logger.warning(f"a, b are in different shape. a: {a.shape}, b: {b.shape}")
    return True
28
+
29
+
30
def l2_distance(a, b):
    """Euclidean distance between two arrays; None if either is missing or shapes differ."""
    if a is None or b is None:
        return None
    if in_different_shape(a, b):
        return None
    diff_norm = np.linalg.norm(a - b)
    return diff_norm.item()
36
+
37
+
38
def cos_sim(a, b):
    """Cosine similarity of two arrays.

    Returns None when either input is missing, shapes differ, or exactly one
    input has zero norm; two all-zero inputs are defined as perfectly similar (1.0).
    """
    if a is None or b is None:
        return None

    if in_different_shape(a, b):
        return None

    if a.ndim > 0:
        # Work on 1-D views so dot() produces a scalar.
        a = a.flatten().squeeze()
        b = b.flatten().squeeze()

    dot_product = a.dot(b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)

    if norm_a == 0 and norm_b == 0:
        return 1.
    if norm_a == 0 or norm_b == 0:
        logger.warning(f'One tensor norm is zero.')
        return None

    similarity = dot_product / (norm_a * norm_b)

    return similarity.item()
61
+
62
+
63
def numel(a, b):
    """Element count of both arrays: the shared count, or an (n1, n2) pair when they differ."""
    count_a, count_b = a.size, b.size
    if count_a == count_b:
        return count_a
    logger.warning('parameters have different number of element')
    return (count_a, count_b)
70
+
71
+
72
def shape(a, b):
    """Shape of both arrays: one list when they agree, else a [shape_a, shape_b] pair."""
    same = not in_different_shape(a, b)
    return list(a.shape) if same else [list(a.shape), list(b.shape)]
76
+
77
+
78
# Registry mapping a metric name to its implementation; consumers dispatch
# through this dict rather than calling the functions directly.
METRIC_FUNC = {
    'l2': l2_distance,
    'cos': cos_sim,
    'numel': numel,
    'shape': shape
}
@@ -0,0 +1,12 @@
1
+ self_attention.linear_qkv.layer_norm_: input_layernorm.
2
+ language_model.: ''
3
+ encoder: decoder
4
+ .input_norm.: .input_layernorm.
5
+ query_key_value: linear_qkv
6
+ .dense.: .linear_proj.
7
+ post_attention_norm: pre_mlp_layernorm
8
+ dense_h_to_4h: linear_fc1
9
+ dense_4h_to_h: linear_fc2
10
+ mlp.local_experts: mlp.experts.local_experts
11
+ final_norm: final_layernorm
12
+ word_embeddings_for_head: output_layer
@@ -0,0 +1,51 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.config_check.config_checker import ConfigChecker
17
+ from msprobe.core.config_check.ckpt_compare.ckpt_comparator import compare_checkpoints
18
+ from msprobe.core.common.log import logger
19
+
20
+
21
def pack(shell_path, output_path, framework):
    """Collect the training config described by shell_path into a zip at output_path.

    ConfigChecker's constructor performs the packing as a side effect; the
    instance itself is intentionally discarded.
    """
    ConfigChecker(shell_path=shell_path, output_zip_path=output_path, fmk=framework)
23
+
24
+
25
def compare(bench_zip_path, cmp_zip_path, output_path, framework):
    """Compare two packed config zips and write the result under output_path."""
    ConfigChecker.compare(bench_zip_path, cmp_zip_path, output_path, framework)
27
+
28
+
29
+ def _config_checking_parser(parser):
30
+ parser.add_argument('-d', '--dump', nargs='*', help='Collect the train config into a zip file')
31
+ parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or checkpoints')
32
+ parser.add_argument('-o', '--output', help='output path, default is current directory')
33
+
34
+
35
def _run_config_checking_command(args):
    """Dispatch the config-checking CLI: dump (-d) packs the config, compare (-c)
    diffs two packed zips or two checkpoints.

    Raises:
        ValueError: when neither -d nor -c was supplied.
    """
    if args.dump is not None:
        output_dirpath = args.output if args.output else "./config_check_pack.zip"
        pack(args.dump, output_dirpath, args.framework)
    elif args.compare:
        # Match the '.zip' suffix including the dot: the previous
        # endswith('zip') also matched paths like 'ckpt_zip' that are not
        # archives, mis-routing them to the packed-config branch.
        if args.compare[0].endswith('.zip'):
            logger.info('The input paths is zip files, comparing packed config.')
            output_dirpath = args.output if args.output else "./config_check_result"
            compare(args.compare[0], args.compare[1], output_dirpath, args.framework)
        else:
            logger.info('Comparing model checkpoint.')
            output_dirpath = args.output if args.output else "./ckpt_similarity.json"
            compare_checkpoints(args.compare[0], args.compare[1], output_dirpath)
    else:
        logger.error("The param is not correct, you need to give '-d' for dump or '-c' for compare.")
        # ValueError is a subclass of Exception, so callers catching the old
        # generic Exception still work, but the failure type is now specific.
        raise ValueError("The param is not correct, you need to give '-d' for dump or '-c' for compare.")
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import shutil
18
+
19
+ import pandas as pd
20
+
21
+ from msprobe.core.common.file_utils import save_excel, split_zip_file_path, \
22
+ create_directory, extract_zip
23
+ from msprobe.core.common.framework_adapter import FmkAdp
24
+ from msprobe.core.config_check.checkers.base_checker import PackInput
25
+ from msprobe.core.config_check.utils.utils import config_checking_print
26
+ from msprobe.core.common.const import Const
27
+
28
+
29
class ConfigChecker:
    """Packs training configuration into a zip and compares two packed zips.

    Checker classes register themselves via register_checker_item; the
    registry and step counter are class-level by design (one per process).
    """

    # Registry of checker classes, keyed by checker name
    # (filled by register_checker_item).
    checkers = {}
    # Callables to run on every forward-pre hook
    # (filled by register_pre_forward_fun_list).
    pre_forward_fun_list = []
    # Output workbook name and the columns of its summary sheet.
    result_filename = "result.xlsx"
    result_header = ["filename", "pass_check"]
    # Global forward-step counter, incremented once per model forward.
    step = 0

    def __init__(self, model=None, shell_path=None, output_zip_path="./config_check_pack.zip", fmk="pytorch"):
        """Construct and immediately collect: the constructor runs pack() as a side effect."""
        FmkAdp.set_fmk(fmk)
        self.pack_input = PackInput(output_zip_path, model, shell_path)
        # Only the directory part is needed here; file_name is unused.
        file_path, file_name = split_zip_file_path(self.pack_input.output_zip_path)
        if not os.path.exists(file_path):
            create_directory(file_path)
        self.pack()

    @staticmethod
    def compare(bench_zip_path, cmp_zip_path, output_path, fmk=Const.PT_FRAMEWORK):
        """Extract both zips under output_path, run every registered checker,
        and save a workbook (summary sheet first, then per-checker sheets).

        NOTE(review): an existing output_path is removed wholesale before use.
        """
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        bench_dir = os.path.join(output_path, "bench")
        cmp_dir = os.path.join(output_path, "cmp")
        extract_zip(bench_zip_path, bench_dir)
        config_checking_print(f"extract zip file {bench_zip_path} to {bench_dir}")
        extract_zip(cmp_zip_path, cmp_dir)
        config_checking_print(f"extract zip file {cmp_zip_path} to {cmp_dir}")

        result = []
        summary_result = []
        for checker in ConfigChecker.checkers.values():
            # compare_ex returns (name, pass/fail, detail DataFrame or None).
            checker_name, pass_check, df = checker.compare_ex(bench_dir, cmp_dir, output_path, fmk)
            if checker_name:
                summary_result.append([checker_name, pass_check])
            if df is not None:
                result.append((df, checker_name))
        summary_result_df = pd.DataFrame(summary_result, columns=ConfigChecker.result_header)
        result.insert(0, (summary_result_df, "summary"))
        save_excel(os.path.join(output_path, ConfigChecker.result_filename), result)
        config_checking_print(f"config checking result save to {os.path.realpath(output_path)}")

    @staticmethod
    def apply_patches(fmk=Const.PT_FRAMEWORK):
        """Let every registered checker install its framework-specific patches."""
        for checker in ConfigChecker.checkers.values():
            checker.apply_patches(fmk)

    def pack(self):
        """Run each eligible registered checker's pack() to fill the output zip."""
        config_checking_print(f"pack result zip path {os.path.realpath(self.pack_input.output_zip_path)}")

        def hook(model, args, kwargs):
            # Forward-pre hook: run all registered collectors for this step,
            # then advance the shared class-level step counter.
            for collect_func in self.pre_forward_fun_list:
                collect_func(model, args, kwargs, ConfigChecker.step)
            ConfigChecker.step += 1

        if self.pack_input.model:
            FmkAdp.register_forward_pre_hook(self.pack_input.model, hook, with_kwargs=True)
        for checker in ConfigChecker.checkers.values():
            # Skip checkers whose required input was not provided.
            if checker.input_needed and not getattr(self.pack_input, checker.input_needed):
                continue
            # In distributed runs, single-rank checkers execute on rank 0 only.
            if FmkAdp.is_initialized() and FmkAdp.get_rank() != 0 and not checker.multi_rank:
                continue
            checker.pack(self.pack_input)
90
+
91
def register_checker_item(key, cls=None):
    """Register a checker class under key; usable directly or as a decorator."""
    if cls is not None:
        ConfigChecker.checkers[key] = cls
        return cls
    # Called with only the key: behave as a parameterized decorator.
    return lambda c: register_checker_item(key, c)
97
+
98
+
99
def register_pre_forward_fun_list(func):
    # Queue func to run on every model forward-pre hook; invoked with
    # (model, args, kwargs, step) by the hook installed in ConfigChecker.pack.
    ConfigChecker.pre_forward_fun_list.append(func)
@@ -13,7 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from msprobe.pytorch.parse_tool import cli
17
-
18
- if __name__ == '__main__':
19
- cli.parse()
16
+ dependency:
17
+ - transformers
18
+ - deepspeed
19
+ - megatron
20
+ - numpy
21
+ - datasets
22
+ - peft
@@ -0,0 +1,57 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ HCCL_DETERMINISTIC:
17
+ npu:
18
+ name: HCCL_DETERMINISTIC
19
+ default_value: False
20
+ gpu:
21
+ name: NCCL_DETERMINISTIC
22
+ default_value: False
23
+
24
+ HCCL_ALGO:
25
+ npu:
26
+ name: HCCL_ALGO
27
+ default_value: None
28
+ gpu:
29
+ name: NCCL_ALGO
30
+ default_value: None
31
+
32
+ HCCL_INTRA_ROCE_ENABLE:
33
+ npu:
34
+ name: HCCL_INTRA_ROCE_ENABLE
35
+ default_value: 0
36
+
37
+
38
+ HCCL_INTRA_PICE_ENABLE:
39
+ npu:
40
+ name: HCCL_INTRA_PCIE_ENABLE
41
+ default_value: 1
42
+
43
+ ASCEND_LAUNCH_BLOCKING:
44
+ npu:
45
+ name: ASCEND_LAUNCH_BLOCKING
46
+ default_value: 0
47
+ gpu:
48
+ name: CUDA_LAUNCH_BLOCKING
49
+ default_value: 0
50
+
51
+ ASCEND_RT_VISIBLE_DEVICES:
52
+ npu:
53
+ name: ASCEND_RT_VISIBLE_DEVICES
54
+ default_value: None
55
+ gpu:
56
+ name: CUDA_VISIBLE_DEVICES
57
+ default_value: None
@@ -0,0 +1,21 @@
1
+ learning_rate:
2
+ - lr
3
+ - learningrate
4
+
5
+ batch_size:
6
+ - batch
7
+ - bs
8
+ - batch_size_per_gpu
9
+
10
+ epochs:
11
+ - num_epochs
12
+ - max_epochs
13
+ - epoch
14
+
15
+ weight_decay:
16
+ - wd
17
+ - weightdecay
18
+
19
+ dropout_rate:
20
+ - dropout
21
+ - drop_rate